2424import numpy as np
2525import tensorflow as tf
2626from MaxText import tokenizer
27- from MaxText import multimodal_utils
27+ from maxtext .multimodal import processor as mm_processor
28+ from maxtext .multimodal import utils as mm_utils
2829from maxtext .utils import max_logging
2930
# Type alias for a single dataset example: feature name -> TensorFlow tensor.
Features = dict[str, tf.Tensor]
@@ -73,13 +74,13 @@ def reformat_prompt(example, column, image_placeholder, model_name):
7374 num_images = len (example ["images" ])
7475 else :
7576 num_images = 1
76- example [column ] = multimodal_utils .reformat_prompt (example [column ], image_placeholder , model_name , num_images )
77+ example [column ] = mm_processor .reformat_prompt (example [column ], image_placeholder , model_name , num_images )
7778 return example
7879
7980
def reformat_response(example, column, model_name):
  """reformat response for multimodal SFT"""
  # Only the first response entry is kept and rewritten into the
  # model-specific response format expected by `model_name`.
  raw_response = example[column][0]
  example[column] = mm_processor.reformat_response(raw_response, model_name)
  return example
8485
8586
@@ -101,11 +102,11 @@ def pre_process_image_sft(example, image_column, model_name):
101102
102103 def _process_image_fn (image ):
103104 if isinstance (image , list ):
104- image = [np .array (multimodal_utils .convert_to_RGB (img )) for img in image ]
105+ image = [np .array (mm_utils .convert_to_RGB (img )) for img in image ]
105106 else :
106- image = np .array (multimodal_utils .convert_to_RGB (image ))
107+ image = np .array (mm_utils .convert_to_RGB (image ))
107108
108- image = multimodal_utils . pre_process_image (image , model_name )
109+ image = mm_processor . preprocess_image_for_training (image , model_name )
109110 return image
110111
111112 example [image_column ] = _process_image_fn (example [image_column ])
@@ -114,7 +115,7 @@ def _process_image_fn(image):
114115
def prepare_text_for_image_fusion(example, column_name, model_name):
  """prepare text for image fusion for multimodal SFT"""
  # The preprocessed images in example["images"] drive where image tokens
  # are fused into the text column for this model.
  fused_text = mm_processor.prepare_text_for_image_fusion(
      example[column_name], model_name, processor_output=example["images"]
  )
  example[column_name] = fused_text
  return example
@@ -478,9 +479,7 @@ def _pad_text(self, x: np.ndarray, max_length: int, pad_id: int) -> np.ndarray:
478479 pad_amount = [(0 , pad_amount )] + [(0 , 0 )] * (len (x .shape ) - 1 )
479480 return np .pad (x , pad_amount , constant_values = pad_id )[: self .max_length ]
480481
481- def _pad_image_and_mask (
482- self , preprocessed_image : multimodal_utils .PreprocessorOutput
483- ) -> multimodal_utils .PreprocessorOutput :
482+ def _pad_image_and_mask (self , preprocessed_image : mm_utils .PreprocessorOutput ) -> mm_utils .PreprocessorOutput :
484483 """Pads the input tensors (image and mask) of a PreprocessorOutput to a maximum number of items.
485484
486485 This function unifies padding logic for image tensors (standard or tiled) and
@@ -513,14 +512,14 @@ def _pad_image_and_mask(
513512 - The dummy images used for padding are based on the image shape for initialization
514513 of this model (ignoring batch size).
515514 """
516- if not isinstance (preprocessed_image , multimodal_utils .PreprocessorOutput ):
515+ if not isinstance (preprocessed_image , mm_utils .PreprocessorOutput ):
517516 raise TypeError (f"Input must be multimodal_utils.PreprocessorOutput, but got { type (preprocessed_image )} " )
518517
519518 if preprocessed_image .pixel_values is None :
520519 raise ValueError ("Input preprocessed_image must have pixel_values to pad images." )
521520
522521 # Determine the maximum number of images/masks allowed.
523- image_offsets = multimodal_utils .get_image_offsets (self .model_name , preprocessed_image )
522+ image_offsets = mm_processor .get_image_offsets (self .model_name , preprocessed_image )
524523 single_image_offset = image_offsets // preprocessed_image .pixel_values .shape [0 ]
525524
526525 # Reserve space for at least one text token.
@@ -569,13 +568,13 @@ def _pad(tensor: np.ndarray) -> np.ndarray:
569568 return preprocessed_image
570569
571570 def map (
572- self , element : dict [str , np .ndarray | multimodal_utils .PreprocessorOutput ]
573- ) -> dict [str , np .ndarray | multimodal_utils .PreprocessorOutput ]:
571+ self , element : dict [str , np .ndarray | mm_utils .PreprocessorOutput ]
572+ ) -> dict [str , np .ndarray | mm_utils .PreprocessorOutput ]:
574573 """map to each element"""
575574 data_columns = list (element .keys ())
576575 for data_column in data_columns :
577576 if data_column != "images" :
578- if isinstance (element [data_column ], multimodal_utils .PreprocessorOutput ):
577+ if isinstance (element [data_column ], mm_utils .PreprocessorOutput ):
579578 raise TypeError ("Only 'images' column can be of type PreprocessorOutput." )
580579
581580 element [f"{ data_column } _segmentation" ] = element [data_column ] != self .pad_id
@@ -615,7 +614,7 @@ def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
615614 if preprocessed_image is None :
616615 return element
617616
618- if not isinstance (preprocessed_image , multimodal_utils .PreprocessorOutput ):
617+ if not isinstance (preprocessed_image , mm_utils .PreprocessorOutput ):
619618 raise TypeError (f"'images' must be of type PreprocessorOutput, but got { type (preprocessed_image )} " )
620619
621620 output = element .copy ()
@@ -646,7 +645,7 @@ class FoldImagesIntoBatch(grain.MapTransform):
646645
647646 def __post_init__ (self ):
648647 """Initializes the target shape after the dataclass is created."""
649- self .target_shape = multimodal_utils .get_dummy_image_shape_for_init (self .model_name )
648+ self .target_shape = mm_processor .get_dummy_image_shape_for_init (self .model_name )
650649
651650 def map (self , element : dict [str , np .ndarray ]) -> dict [str , np .ndarray ]:
652651 """Applies the folding transformation to the 'images' field if present."""
@@ -777,7 +776,10 @@ def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
777776 second_per_grids = element .get ("second_per_grids" )
778777
779778 # Call the standalone get_rope_index function from multimodal_utils
780- position_ids , mrope_position_deltas = multimodal_utils .get_rope_index (
779+ from maxtext .multimodal import processor_qwen3_omni # pylint: disable=import-outside-toplevel
780+
781+ # TODO(jfacevedo/hengtaoguo): Now get_rope_index is Qwen3-Omni specific. We should generalize it for other models
782+ position_ids , mrope_position_deltas = processor_qwen3_omni .get_rope_index (
781783 input_ids = input_ids ,
782784 image_grid_thw = image_grid_thw ,
783785 video_grid_thw = video_grid_thw ,
0 commit comments