Commit 27eada9

Merge pull request #3052 from AI-Hypercomputer:hengtaoguo-utils2
PiperOrigin-RevId: 866156329
2 parents 9fe6cdf + 4dfcb23 commit 27eada9
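
This merge splits the old MaxText.multimodal_utils module across the MaxText.multimodal package: model-aware preprocessing helpers now live in MaxText.multimodal.processor (imported below as mm_processor), while model-agnostic pieces such as PreprocessorOutput live in MaxText.multimodal.utils (imported as mm_utils). Callers in the API server generator, input pipeline, decoder layers, and engine are updated accordingly, and the stale src/MaxText/multimodal/preprocessor.py is deleted.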

18 files changed

Lines changed: 1644 additions & 1569 deletions

benchmarks/api_server/maxtext_generator.py

Lines changed: 13 additions & 9 deletions
@@ -34,7 +34,9 @@
 
 from dataclasses import dataclass, field
 
-from MaxText import maxengine, pyconfig, multimodal_utils
+from MaxText import maxengine, pyconfig
+from MaxText.multimodal import processor as mm_processor
+from MaxText.multimodal import utils as mm_utils
 from maxtext.utils import max_logging, max_utils
 
 # Set TF log level to avoid verbose startup messages.
@@ -493,23 +495,25 @@ def _build_completions(self, streams, logprobs, echo):
 
   def _preprocess_inputs(self, text, prefill_length, image_path):
     """Helper to preprocess a single text and optional image input."""
-    processor_output = multimodal_utils.PreprocessorOutput()
+    processor_output = mm_utils.PreprocessorOutput()
     images = None
     if self.config.use_multimodal and image_path:
-      text = multimodal_utils.reformat_prompt(
-          text, image_placeholder=self.config.image_placeholder, model_name=self.config.model_name, num_images=1
+      text = mm_processor.reformat_prompt(
+          prompt=self.config.prompt,
+          image_placeholder=self.config.image_placeholder,
+          model_name=self.config.model_name,
+          num_images=1,
       )
-      loaded_images = multimodal_utils.load_image_from_path(image_path)
-      processor_output = multimodal_utils.pre_process_image(loaded_images, model_name=self.config.model_name)
-      prefill_length -= multimodal_utils.get_image_offsets(self.config.model_name, processor_output=processor_output)
+      processor_output = mm_processor.preprocess_mm_data(self.config)
+      prefill_length -= mm_processor.get_image_offsets(self.config.model_name, processor_output=processor_output)
       images = processor_output.pixel_values
 
     tokens, true_length = self.tokenizer.encode(text, is_bos=not self.has_chat_template, prefill_lengths=[prefill_length])
     if self.config.use_multimodal and image_path:
-      tokens = multimodal_utils.prepare_text_for_image_fusion(
+      tokens = mm_processor.prepare_text_for_image_fusion(
           tokens, model_name=self.config.model_name, processor_output=processor_output
       )
-      true_length += multimodal_utils.get_image_offsets(self.config.model_name, processor_output=processor_output)
+      true_length += mm_processor.get_image_offsets(self.config.model_name, processor_output=processor_output)
 
     return tokens, true_length, images
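
The refactor keeps the original prefill bookkeeping: positions for the fused image tokens are subtracted from the prefill budget before text tokenization and added back to true_length after fusion. A minimal self-contained sketch of that arithmetic (the budget and offset values below are hypothetical; MaxText derives the real offsets via mm_processor.get_image_offsets):

# Toy model of the prefill budget accounting in _preprocess_inputs above.
# All numbers are hypothetical; MaxText computes image offsets per model.
PREFILL_BUDGET = 1024  # positions available for the prefill
IMAGE_OFFSETS = 256    # positions the fused image tokens will occupy

def preprocess(num_text_tokens: int) -> tuple[int, int]:
  prefill_length = PREFILL_BUDGET - IMAGE_OFFSETS  # reserve room for image tokens
  true_length = min(num_text_tokens, prefill_length)
  true_length += IMAGE_OFFSETS  # fused image tokens count toward the true length
  return prefill_length, true_length

print(preprocess(300))  # (768, 556)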

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 20 additions & 18 deletions
@@ -24,7 +24,8 @@
 import numpy as np
 import tensorflow as tf
 from MaxText import tokenizer
-from MaxText import multimodal_utils
+from MaxText.multimodal import processor as mm_processor
+from MaxText.multimodal import utils as mm_utils
 from maxtext.utils import max_logging
 
 Features = dict[str, tf.Tensor]
@@ -73,13 +74,13 @@ def reformat_prompt(example, column, image_placeholder, model_name):
     num_images = len(example["images"])
   else:
     num_images = 1
-  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name, num_images)
+  example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images)
   return example
 
 
 def reformat_response(example, column, model_name):
   """reformat response for multimodal SFT"""
-  example[column] = multimodal_utils.reformat_response(example[column][0], model_name)
+  example[column] = mm_processor.reformat_response(example[column][0], model_name)
   return example
 
 
@@ -101,11 +102,11 @@ def pre_process_image_sft(example, image_column, model_name):
 
   def _process_image_fn(image):
     if isinstance(image, list):
-      image = [np.array(multimodal_utils.convert_to_RGB(img)) for img in image]
+      image = [np.array(mm_utils.convert_to_RGB(img)) for img in image]
     else:
-      image = np.array(multimodal_utils.convert_to_RGB(image))
+      image = np.array(mm_utils.convert_to_RGB(image))
 
-    image = multimodal_utils.pre_process_image(image, model_name)
+    image = mm_processor.preprocess_image_for_training(image, model_name)
     return image
 
   example[image_column] = _process_image_fn(example[image_column])
@@ -114,7 +115,7 @@ def _process_image_fn(image):
 
 def prepare_text_for_image_fusion(example, column_name, model_name):
   """prepare text for image fusion for multimodal SFT"""
-  example[column_name] = multimodal_utils.prepare_text_for_image_fusion(
+  example[column_name] = mm_processor.prepare_text_for_image_fusion(
       example[column_name], model_name, processor_output=example["images"]
   )
   return example
@@ -478,9 +479,7 @@ def _pad_text(self, x: np.ndarray, max_length: int, pad_id: int) -> np.ndarray:
     pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
     return np.pad(x, pad_amount, constant_values=pad_id)[: self.max_length]
 
-  def _pad_image_and_mask(
-      self, preprocessed_image: multimodal_utils.PreprocessorOutput
-  ) -> multimodal_utils.PreprocessorOutput:
+  def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -> mm_utils.PreprocessorOutput:
     """Pads the input tensors (image and mask) of a PreprocessorOutput to a maximum number of items.
 
     This function unifies padding logic for image tensors (standard or tiled) and
@@ -513,14 +512,14 @@ def _pad_image_and_mask(
     - The dummy images used for padding are based on the image shape for initialization
       of this model (ignoring batch size).
     """
-    if not isinstance(preprocessed_image, multimodal_utils.PreprocessorOutput):
+    if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput):
       raise TypeError(f"Input must be multimodal_utils.PreprocessorOutput, but got {type(preprocessed_image)}")
 
     if preprocessed_image.pixel_values is None:
       raise ValueError("Input preprocessed_image must have pixel_values to pad images.")
 
     # Determine the maximum number of images/masks allowed.
-    image_offsets = multimodal_utils.get_image_offsets(self.model_name, preprocessed_image)
+    image_offsets = mm_processor.get_image_offsets(self.model_name, preprocessed_image)
     single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]
 
     # Reserve space for at least one text token.
@@ -569,13 +568,13 @@ def _pad(tensor: np.ndarray) -> np.ndarray:
     return preprocessed_image
 
   def map(
-      self, element: dict[str, np.ndarray | multimodal_utils.PreprocessorOutput]
-  ) -> dict[str, np.ndarray | multimodal_utils.PreprocessorOutput]:
+      self, element: dict[str, np.ndarray | mm_utils.PreprocessorOutput]
+  ) -> dict[str, np.ndarray | mm_utils.PreprocessorOutput]:
     """map to each element"""
     data_columns = list(element.keys())
     for data_column in data_columns:
       if data_column != "images":
-        if isinstance(element[data_column], multimodal_utils.PreprocessorOutput):
+        if isinstance(element[data_column], mm_utils.PreprocessorOutput):
          raise TypeError("Only 'images' column can be of type PreprocessorOutput.")
 
      element[f"{data_column}_segmentation"] = element[data_column] != self.pad_id
@@ -615,7 +614,7 @@ def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
     if preprocessed_image is None:
       return element
 
-    if not isinstance(preprocessed_image, multimodal_utils.PreprocessorOutput):
+    if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput):
       raise TypeError(f"'images' must be of type PreprocessorOutput, but got {type(preprocessed_image)}")
 
     output = element.copy()
@@ -646,7 +645,7 @@ class FoldImagesIntoBatch(grain.MapTransform):
 
   def __post_init__(self):
     """Initializes the target shape after the dataclass is created."""
-    self.target_shape = multimodal_utils.get_dummy_image_shape_for_init(self.model_name)
+    self.target_shape = mm_processor.get_dummy_image_shape_for_init(self.model_name)
 
   def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
     """Applies the folding transformation to the 'images' field if present."""
@@ -777,7 +776,10 @@ def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
     second_per_grids = element.get("second_per_grids")
 
     # Call the standalone get_rope_index function from multimodal_utils
-    position_ids, mrope_position_deltas = multimodal_utils.get_rope_index(
+    from MaxText.multimodal import processor_qwen3_omni  # pylint: disable=import-outside-toplevel
+
+    # TODO(jfacevedo/hengtaoguo): Now get_rope_index is Qwen3-Omni specific. We should generalize it for other models
+    position_ids, mrope_position_deltas = processor_qwen3_omni.get_rope_index(
         input_ids=input_ids,
         image_grid_thw=image_grid_thw,
         video_grid_thw=video_grid_thw,
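
Elsewhere in this file, _pad_image_and_mask caps the number of images per sequence: the per-image token cost is the total image offsets divided by the image count, and at least one position is reserved for a text token. A self-contained sketch of that budget (shapes and counts are made up; the real offsets come from mm_processor.get_image_offsets):

import numpy as np

max_length = 1024          # hypothetical sequence budget
num_images = 3
total_image_offsets = 600  # hypothetical; per-model value in MaxText

single_image_offset = total_image_offsets // num_images   # 200 positions per image
max_num_images = (max_length - 1) // single_image_offset  # reserve one text token -> 5

# Pad the stacked pixel_values up to max_num_images with zero-valued dummy images.
pixel_values = np.ones((num_images, 224, 224, 3), dtype=np.float32)
pad = [(0, max_num_images - num_images)] + [(0, 0)] * (pixel_values.ndim - 1)
print(np.pad(pixel_values, pad, constant_values=0.0).shape)  # (5, 224, 224, 3)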

src/MaxText/layers/decoders.py

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,6 @@
 from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
-from MaxText import multimodal_utils
 from MaxText import sharding
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.normalizations import rms_norm
@@ -57,6 +56,7 @@
     olmo3,
 )
 from maxtext.inference import page_manager
+from MaxText.multimodal import utils as mm_utils
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
@@ -587,7 +587,7 @@ def _apply_embedding(
         "llama4-17b-128e",
         "qwen3-omni-30b-a3b",
     ]:
-      y = multimodal_utils.merge_mm_embeddings(
+      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
@@ -599,7 +599,7 @@ def _apply_embedding(
 
     if audio_embeddings is not None and cfg.use_audio:
       if cfg.model_name in ["qwen3-omni-30b-a3b"]:
-        y = multimodal_utils.merge_mm_embeddings(
+        y = mm_utils.merge_mm_embeddings(
            text_embeddings=y,
            multimodal_embeddings=audio_embeddings,
            mask=audio_masks,
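
merge_mm_embeddings, now imported from MaxText.multimodal.utils, writes encoder outputs into the text embedding sequence at the positions flagged by the mask. A toy single-sequence version of the idea (the real helper additionally handles batching and sharding):

import jax.numpy as jnp

def merge_mm_embeddings_toy(text_embeddings, multimodal_embeddings, mask):
  # Scatter the multimodal embeddings into the masked placeholder slots, in order.
  positions = jnp.nonzero(mask, size=multimodal_embeddings.shape[0])[0]
  return text_embeddings.at[positions].set(multimodal_embeddings)

text = jnp.zeros((8, 4))                                # (seq, embed_dim)
image = jnp.ones((3, 4))                                # patches for 3 placeholder slots
mask = jnp.array([0, 1, 1, 1, 0, 0, 0, 0], dtype=bool)  # placeholder positions
print(merge_mm_embeddings_toy(text, image, mask)[1:4].sum())  # 12.0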

src/MaxText/layers/models.py

Lines changed: 8 additions & 25 deletions
@@ -25,15 +25,15 @@
 from flax import nnx
 from MaxText.layers import initializers
 
-from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR
-from MaxText import multimodal_utils
+from MaxText.common_types import Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.decoders import Decoder
 from MaxText.layers.embeddings import Embed, embed_as_linen
 from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen, AudioEncoder, audio_encoder_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
 from maxtext.inference import page_manager
+from MaxText.multimodal import processor as mm_processor
 from maxtext.utils import max_utils
 
 # ------------------------------------------------------------------------------
@@ -155,24 +155,15 @@ def __call__(
 
     if self.config.use_multimodal and encoder_images is not None:
       image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
+      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
 
-      if self.config.decoder_block == DecoderBlockType.GEMMA3:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER
-      elif self.config.decoder_block == DecoderBlockType.LLAMA4:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.LLAMA4_PATCH_TOKEN
-      elif self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-        # Create bidirectional_mask for vision/video token merging
-        bidirectional_mask = (decoder_input_tokens == multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN) | (
-            decoder_input_tokens == multimodal_utils.QWEN3_OMNI_VIDEO_TOKEN
-        )
-        # Create image/video mask for deepstack visual embedding injection
     if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
       audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)
 
     # Create audio mask for placeholder tokens (qwen3-omni models)
     audio_masks = None
-    if audio_embeddings is not None and self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-      audio_masks = decoder_input_tokens == multimodal_utils.QWEN3_OMNI_AUDIO_TOKEN
+    if audio_embeddings is not None:
+      audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)
 
     logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.shared_embedding,
@@ -469,24 +460,16 @@ def __call__(
     image_embeddings = None
     if self.config.use_multimodal and encoder_images is not None:
       image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
-
-      if self.config.decoder_block == DecoderBlockType.GEMMA3:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER
-      elif self.config.decoder_block == DecoderBlockType.LLAMA4:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.LLAMA4_PATCH_TOKEN
-      elif self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-        bidirectional_mask = (decoder_input_tokens == multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN) | (
-            decoder_input_tokens == multimodal_utils.QWEN3_OMNI_VIDEO_TOKEN
-        )
+      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
 
     audio_embeddings = None
     if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
       audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)
 
     # Create audio mask for placeholder tokens (qwen3-omni models)
     audio_masks = None
-    if audio_embeddings is not None and self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-      audio_masks = decoder_input_tokens == multimodal_utils.QWEN3_OMNI_AUDIO_TOKEN
+    if audio_embeddings is not None:
+      audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)
 
     logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.token_embedder,
src/MaxText/maxengine.py

Lines changed: 8 additions & 11 deletions
@@ -36,13 +36,13 @@
 from flax.linen import partitioning as nn_partitioning
 import flax
 
-from MaxText import multimodal_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.layers import models, quantizations
 from maxtext.inference import inference_utils
 from maxtext.inference.page_manager import PageManager, PageState
+from MaxText.multimodal import processor as mm_processor
 from maxtext.utils import lora_utils
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
@@ -325,12 +325,11 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None):
 
     @jax.jit
     def model_apply(_p, _rng):
-      image_shape = multimodal_utils.get_dummy_image_shape_for_init(
-          self.config.model_name, batch_size=self.config.micro_batch_size_to_train_on
-      )
-      audio_shape = multimodal_utils.get_dummy_audio_shape_for_init(
-          self.config.model_name, config=self.config, batch_size=self.config.micro_batch_size_to_train_on
+      image_shape = mm_processor.get_dummy_image_shape_for_init(
+          model_name=self.config.model_name,
+          batch_size=self.config.micro_batch_size_to_train_on,
       )
+      audio_shape = mm_processor.get_dummy_audio_shape_for_init(self.config)
       return self.model.apply(
           _p | {"aqt": {}},
           jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32),
@@ -1584,15 +1583,13 @@ def init(abstract_params, page_state):
           dtype=jnp.int32,
       )
       dummy_image = jnp.ones(
-          multimodal_utils.get_dummy_image_shape_for_init(
-              self.config.model_name, batch_size=self.config.micro_batch_size_to_train_on
+          mm_processor.get_dummy_image_shape_for_init(
+              model_name=self.config.model_name, batch_size=self.config.per_device_batch_size
          ),
          dtype=jnp.int32,
       )
       dummy_audio = jnp.ones(
-          multimodal_utils.get_dummy_audio_shape_for_init(
-              self.config.model_name, config=self.config, batch_size=self.config.micro_batch_size_to_train_on
-          ),
+          mm_processor.get_dummy_audio_shape_for_init(self.config),
          dtype=jnp.float32,
       )
       _, cache = self.model.apply(
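
Two call-signature changes to note in this file: get_dummy_audio_shape_for_init now takes only the config, and the prefill-cache dummy image is sized by per_device_batch_size instead of micro_batch_size_to_train_on. A sketch of how these dummy tensors are used purely to trace shapes during initialization (the shape itself is hypothetical; MaxText derives it per model):

import jax.numpy as jnp

def get_dummy_image_shape_for_init_toy(model_name: str, batch_size: int):
  return (batch_size, 896, 896, 3)  # hypothetical per-model image shape

# Shape-only placeholder: its values never matter, only its shape and dtype.
dummy_image = jnp.ones(get_dummy_image_shape_for_init_toy("gemma3-4b", batch_size=1), dtype=jnp.int32)
print(dummy_image.shape)  # (1, 896, 896, 3)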

src/MaxText/multimodal/preprocessor.py

Lines changed: 0 additions & 47 deletions
This file was deleted.
