Skip to content

Commit 5e78bbf

Browse files
eitanporat, hengtaoguo
authored and committed
Add Preprocessing and token placeholders
1 parent 40071fc commit 5e78bbf

11 files changed

Lines changed: 168 additions & 46 deletions

File tree

benchmarks/api_server/maxtext_generator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -505,15 +505,13 @@ def _preprocess_inputs(self, text, prefill_length, image_path):
505505
num_images=1,
506506
)
507507
processor_output = mm_processor.preprocess_mm_data(self.config)
508-
prefill_length -= mm_processor.get_image_offsets(self.config.model_name, processor_output=processor_output)
508+
prefill_length -= mm_processor.get_image_offsets(config=self.config, processor_output=processor_output)
509509
images = processor_output.pixel_values
510510

511511
tokens, true_length = self.tokenizer.encode(text, is_bos=not self.has_chat_template, prefill_lengths=[prefill_length])
512512
if self.config.use_multimodal and image_path:
513-
tokens = mm_processor.prepare_text_for_image_fusion(
514-
tokens, model_name=self.config.model_name, processor_output=processor_output
515-
)
516-
true_length += mm_processor.get_image_offsets(self.config.model_name, processor_output=processor_output)
513+
tokens = mm_processor.prepare_text_for_image_fusion(tokens, config=self.config, processor_output=processor_output)
514+
true_length += mm_processor.get_image_offsets(config=self.config, processor_output=processor_output)
517515

518516
return tokens, true_length, images
519517

src/maxtext/configs/base.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,9 +992,11 @@ dtype_mm: "float32" # Data type for multimodal model's vision encoder
992992
remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
993993
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
994994
image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
995-
image_placeholder: "<|image|>"
996995
video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
997996
audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
997+
image_placeholder: "<|image|>"
998+
video_placeholder: "<|video|>"
999+
audio_placeholder: "<|audio|>"
9981000
use_audio_in_video: False
9991001
posemb_type_for_vit: "learn"
10001002
# max_num_images_per_example only applies for training when your image column is a list of images.

src/maxtext/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,4 @@ max_sample_len_for_audio: 10000
7777
# MRoPE Settings (Multi-dimensional RoPE for multimodal)
7878
use_mrope: true
7979
mrope_section: [24, 20, 20]
80-
position_id_per_seconds: 25
80+
position_id_per_seconds: 13

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,8 @@ class MultimodalGeneral(BaseModel):
14571457
)
14581458
video_path: PathStr = Field("", description="Path to a video for decoding.")
14591459
audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
1460+
video_placeholder: str = Field("<|video|>", description="Placeholder string for video in text prompts.")
1461+
audio_placeholder: str = Field("<|audio|>", description="Placeholder string for audio in text prompts.")
14601462
use_audio_in_video: bool = Field(False, description="Extract and use audio from video files.")
14611463
use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.")
14621464
mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.")

src/maxtext/decode.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,16 @@ def main(argv: Sequence[str]) -> None:
104104
processor_outputs = mm_utils.PreprocessorOutput()
105105
if config.use_multimodal:
106106
processor_outputs = mm_processor.preprocess_mm_data(config)
107-
image_offsets = mm_processor.get_image_offsets(config.model_name, processor_output=processor_outputs)
107+
image_offsets = mm_processor.get_image_offsets(config=config, processor_output=processor_outputs)
108108

109109
prefill_length -= image_offsets
110110
text = mm_processor.reformat_prompt(
111111
prompt=config.prompt,
112112
image_placeholder=config.image_placeholder,
113+
video_placeholder=config.video_placeholder,
113114
model_name=config.model_name,
114115
num_images=processor_outputs.num_images,
116+
num_videos=getattr(processor_outputs, 'num_videos', 0),
115117
)
116118

117119
metadata = engine.get_tokenizer()
@@ -135,9 +137,7 @@ def main(argv: Sequence[str]) -> None:
135137
mrope_position_deltas = None
136138

137139
if config.use_multimodal:
138-
tokens = mm_processor.prepare_text_for_image_fusion(
139-
tokens, model_name=config.model_name, processor_output=processor_outputs
140-
)
140+
tokens = mm_processor.prepare_text_for_image_fusion(tokens=tokens, config=config, processor_output=processor_outputs)
141141
true_length += image_offsets
142142

143143
if config.use_mrope:
@@ -148,7 +148,7 @@ def main(argv: Sequence[str]) -> None:
148148
image_grid_thw=processor_outputs.pixel_grid_thw, # pytype: disable=attribute-error
149149
video_grid_thw=processor_outputs.video_grid_thw, # pytype: disable=attribute-error
150150
attention_mask=np.ones_like(tokens),
151-
use_audio_in_video=config.use_audio and processor_outputs.num_videos > 0, # pytype: disable=attribute-error
151+
use_audio_in_video=config.use_audio and getattr(processor_outputs, 'num_videos', 0) > 0,
152152
audio_lengths=processor_outputs.audio_lengths, # pytype: disable=attribute-error
153153
second_per_grids=processor_outputs.video_second_per_grid, # pytype: disable=attribute-error
154154
spatial_merge_size=config.spatial_merge_size_for_vit, # pytype: disable=attribute-error

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def vision_sft_preprocessing_pipeline(
126126
)
127127
dataset = dataset.map(
128128
input_pipeline_utils.prepare_text_for_image_fusion,
129-
fn_kwargs={"column_name": text_columns[0], "model_name": config.model_name},
129+
fn_kwargs={"column_name": text_columns[0], "config": config},
130130
)
131131

132132
dataset = input_pipeline_utils.HFDataSource(

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ def _process_image_fn(image):
115115
return example
116116

117117

118-
def prepare_text_for_image_fusion(example, column_name, model_name):
118+
def prepare_text_for_image_fusion(example, column_name, config):
119119
"""prepare text for image fusion for multimodal SFT"""
120120
example[column_name] = mm_processor.prepare_text_for_image_fusion(
121-
example[column_name], model_name, processor_output=example["images"]
121+
tokens=example[column_name], config=config, processor_output=example["images"]
122122
)
123123
return example
124124

src/maxtext/multimodal/processor.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,25 @@ def preprocess_image_for_training(image, model_name):
6363
raise ValueError(f"Model {model_name} not supported for image preprocessing.")
6464

6565

66-
def get_image_offsets(model_name, processor_output: mm_utils.PreprocessorOutput | None):
66+
def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | None):
6767
"""Get the increase in total token count after inserting image token placeholders"""
68-
if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
68+
if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
6969
from maxtext.multimodal.processor_gemma3 import get_image_offsets_gemma3 # pylint: disable=import-outside-toplevel
7070

7171
return get_image_offsets_gemma3(processor_output)
72-
elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
72+
elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
7373
from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel
7474

7575
return get_image_offsets_llama4(processor_output)
76+
elif config.model_name in ["qwen3-omni-30b-a3b"]:
77+
from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel
78+
79+
return get_mm_offsets_qwen3_omni(config, processor_output)
7680
else:
7781
return 0
7882

7983

80-
def reformat_prompt(prompt, image_placeholder, model_name, num_images):
84+
def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_placeholder="<|video|>", num_videos=0):
8185
"""Reformat prompt for different models."""
8286
if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
8387
from maxtext.multimodal.processor_gemma3 import reformat_prompt_gemma3 # pylint: disable=import-outside-toplevel
@@ -87,6 +91,16 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images):
8791
from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel
8892

8993
return reformat_prompt_llama4(prompt, image_placeholder, num_images)
94+
elif model_name in ["qwen3-omni-30b-a3b"]:
95+
from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel
96+
97+
return reformat_prompt_qwen3_omni(
98+
prompt=prompt,
99+
image_placeholder=image_placeholder,
100+
num_images=num_images,
101+
video_placeholder=video_placeholder,
102+
num_videos=num_videos,
103+
)
90104
else:
91105
return prompt
92106

@@ -99,22 +113,29 @@ def reformat_response(response, model_name):
99113
elif model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
100114
formatted_response = f"{response}<end_of_turn>"
101115
return formatted_response
116+
elif model_name in ["qwen3-omni-30b-a3b"]:
117+
formatted_response = f"{response}<|im_end|>"
118+
return formatted_response
102119
else:
103120
return response
104121

105122

106-
def prepare_text_for_image_fusion(texts, model_name, processor_output=None):
123+
def prepare_text_for_image_fusion(tokens, config, processor_output=None):
107124
"""Prepare text by adding extra tokens for image fusion based on the model."""
108-
if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
125+
if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
109126
from maxtext.multimodal.processor_gemma3 import add_extra_tokens_for_images_gemma3 # pylint: disable=import-outside-toplevel
110127

111-
return add_extra_tokens_for_images_gemma3(texts, max_num_images=processor_output.num_images)
112-
elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
128+
return add_extra_tokens_for_images_gemma3(tokens, max_num_images=processor_output.num_images)
129+
elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
113130
from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: disable=import-outside-toplevel
114131

115-
return add_extra_tokens_for_images_llama4(texts, processor_output)
132+
return add_extra_tokens_for_images_llama4(tokens, processor_output)
133+
elif config.model_name in ["qwen3-omni-30b-a3b"]:
134+
from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel
135+
136+
return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output)
116137
else:
117-
raise ValueError(f"Model {model_name} does not support multimodal inference.")
138+
raise ValueError(f"Model {config.model_name} does not support multimodal inference.")
118139

119140

120141
def get_dummy_image_shape_for_init(model_name, batch_size=1, num_image_per_sequence=1):

src/maxtext/multimodal/processor_qwen3_omni.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@
5959

6060
# Qwen3OmniMoe-specific processing
6161
QWEN3_OMNI_VISION_START_TOKEN = 151652 # <|vision_start|>
62-
QWEN3_OMNI_VISION_END_TOKEN = 151653 # <|vision_eos|>
62+
QWEN3_OMNI_VISION_END_TOKEN = 151653 # <|vision_end|>
6363
QWEN3_OMNI_IMAGE_TOKEN = 151655 # <|image_pad|>
6464
QWEN3_OMNI_VIDEO_TOKEN = 151656 # <|video_pad|>
6565
QWEN3_OMNI_AUDIO_START_TOKEN = 151669 # <|audio_start|>
66-
QWEN3_OMNI_AUDIO_END_TOKEN = 151648 # <|audio_eos|>
66+
QWEN3_OMNI_AUDIO_END_TOKEN = 151670 # <|audio_end|>
6767
QWEN3_OMNI_AUDIO_TOKEN = 151675 # <|audio_pad|>
6868
QWEN3_TEMPORAL_PATCH_SIZE = 2
6969
QWEN3_OMNI_IMAGE_SIZE = 768
@@ -90,6 +90,7 @@ class Qwen3OmniPreprocessorOutput(mm_utils.PreprocessorOutput):
9090
num_audios: int = 0
9191
audio_values: None | np.ndarray = None
9292
audio_mask: None | np.ndarray = None
93+
audio_lengths: None | np.ndarray = None
9394

9495

9596
def smart_resize(
@@ -477,41 +478,36 @@ def preprocess_mm_data_qwen3_omni(config):
477478
"""Placeholder for multimodal data preprocessing."""
478479
processor_outputs = Qwen3OmniPreprocessorOutput()
479480

480-
if config.image_path is not None:
481+
if config.image_path:
481482
images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")]
482483
pixel_values, pixel_grid_thw = pre_process_qwen3_image(images, config)
483484
processor_outputs.pixel_values = pixel_values
484485
processor_outputs.pixel_grid_thw = pixel_grid_thw
485486
processor_outputs.num_images = len(images)
486487

487-
if config.video_path is not None:
488+
if config.video_path:
488489
video_array, _ = _read_video_decord(config.video_path)
489490
video_processed, video_grid_thw = preprocess_video(video_array, config)
490491
processor_outputs.video_values = video_processed
491492
processor_outputs.video_grid_thw = video_grid_thw
492493
processor_outputs.video_second_per_grid = np.asarray([config.temporal_patch_size_for_vit], dtype=np.float32)
493494
processor_outputs.num_videos = 1 # Only one video for now.
494495

495-
if config.video_path is not None and config.use_audio_in_video:
496+
if config.video_path and config.use_audio_in_video:
496497
# TODO(hengtaoguo): add support for separate audio files. Now only extract audio from video files.
497498
mt_audio = mm_utils.load_audio(config.video_path, sample_rate=SAMPLE_RATE)
498499
mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio)
499500
processor_outputs.audio_values = mt_audio
500501
processor_outputs.audio_mask = mt_audio_mask
502+
# Compute audio_lengths from audio_mask
503+
audio_mask_sum = np.sum(mt_audio_mask, axis=-1)
504+
audio_lengths = _get_feat_extract_output_lengths(audio_mask_sum)
505+
processor_outputs.audio_lengths = np.array(audio_lengths, dtype=np.int32)
501506

502507
return processor_outputs
503508

504509

505-
def add_extra_tokens_for_qwen3_omni(
506-
tokens: np.ndarray | list,
507-
image_grid_thw: np.ndarray | None = None,
508-
video_grid_thw: np.ndarray | None = None,
509-
audio_lengths: np.ndarray | None = None,
510-
spatial_merge_size: int = 2,
511-
use_audio_in_video: bool = False,
512-
second_per_grids: np.ndarray | None = None,
513-
position_id_per_seconds: int = 25,
514-
):
510+
def add_extra_tokens_for_qwen3_omni(tokens, config, processor_output):
515511
"""Add extra tokens for Qwen3-Omni multimodal sequences.
516512
517513
Expands special tokens (<|image_pad|>, <|video_pad|>, <|audio_pad|>) into
@@ -532,6 +528,13 @@ def add_extra_tokens_for_qwen3_omni(
532528
Returns:
533529
Expanded token sequence with correct number of image/video/audio tokens.
534530
"""
531+
image_grid_thw = getattr(processor_output, "pixel_grid_thw", None)
532+
video_grid_thw = getattr(processor_output, "video_grid_thw", None)
533+
audio_lengths = getattr(processor_output, "audio_lengths", None)
534+
second_per_grids = getattr(processor_output, "video_second_per_grid", None)
535+
spatial_merge_size = config.spatial_merge_size_for_vit
536+
position_id_per_seconds = config.position_id_per_seconds
537+
535538
if not isinstance(tokens, np.ndarray):
536539
tokens = np.asarray(tokens)
537540

@@ -561,7 +564,7 @@ def add_extra_tokens_for_qwen3_omni(
561564

562565
# Handle audio-in-video: <|vision_start|><|video_pad|><|vision_end|>
563566
elif (
564-
use_audio_in_video
567+
config.use_audio_in_video
565568
and token == QWEN3_OMNI_VISION_START_TOKEN
566569
and i + 2 < len(token_list)
567570
and token_list[i + 1] == QWEN3_OMNI_VIDEO_TOKEN
@@ -1039,3 +1042,63 @@ def get_rope_index(
10391042
mrope_position_deltas = np.array(mrope_position_deltas).reshape(batch_size, 1)
10401043

10411044
return position_ids, mrope_position_deltas
1045+
1046+
1047+
def reformat_prompt_qwen3_omni(
1048+
prompt, image_placeholder="<|image|>", num_images=0, video_placeholder="<|video|>", num_videos=0
1049+
):
1050+
"""Reformat the prompt for Qwen3-Omni model."""
1051+
# Qwen3-Omni vision format: <|vision_start|><|image_pad|><|vision_end|>
1052+
# Qwen3-Omni mm token order: image_pad, video_pad, audio_pad (standalone audios), then text tokens.
1053+
# use_audio_in_video mode: such audio tokens are interleaved within video tokens.
1054+
qwen3_image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
1055+
qwen3_video_placeholder = "<|vision_start|><|video_pad|><|vision_end|>"
1056+
1057+
if video_placeholder in prompt:
1058+
prompt = prompt.replace(video_placeholder, qwen3_video_placeholder)
1059+
video_placeholder_count = prompt.count(qwen3_video_placeholder)
1060+
if video_placeholder_count < num_videos:
1061+
prompt = qwen3_video_placeholder * (num_videos - video_placeholder_count) + prompt
1062+
1063+
if image_placeholder in prompt:
1064+
prompt = prompt.replace(image_placeholder, qwen3_image_placeholder)
1065+
image_placeholder_count = prompt.count(qwen3_image_placeholder)
1066+
if image_placeholder_count < num_images:
1067+
prompt = qwen3_image_placeholder * (num_images - image_placeholder_count) + prompt
1068+
1069+
# Qwen chat template
1070+
formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
1071+
return formatted_prompt
1072+
1073+
1074+
def get_mm_offsets_qwen3_omni(config, processor_output):
1075+
"""Calculate the token offsets for multimodal tokens in Qwen3-Omni model."""
1076+
# Calculate token expansion for Qwen3-Omni multimodal inputs
1077+
if processor_output is None:
1078+
return 0
1079+
1080+
total_offset = 0
1081+
spatial_merge_size = config.spatial_merge_size_for_vit # Default 2 for Qwen3-Omni
1082+
merge_length = spatial_merge_size**2
1083+
1084+
# Image tokens: <|image_pad|> expands to multiple image tokens
1085+
if processor_output.pixel_grid_thw is not None:
1086+
image_grid_thw = processor_output.pixel_grid_thw
1087+
for grid in image_grid_thw:
1088+
num_image_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
1089+
total_offset += num_image_tokens - 1 # -1 for the original <|image_pad|> token
1090+
1091+
# Video tokens: <|video_pad|> expands to multiple video tokens
1092+
if processor_output.video_grid_thw is not None:
1093+
video_grid_thw = processor_output.video_grid_thw
1094+
for grid in video_grid_thw:
1095+
num_video_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
1096+
total_offset += num_video_tokens - 1 # -1 for the original <|video_pad|> token
1097+
1098+
# Audio tokens: <|audio_pad|> expands based on audio_lengths
1099+
if processor_output.audio_lengths is not None:
1100+
audio_lengths = processor_output.audio_lengths
1101+
for audio_len in audio_lengths:
1102+
total_offset += int(audio_len) - 1 # -1 for the original <|audio_pad|> token
1103+
1104+
return total_offset

tests/unit/multimodal_utils_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
# limitations under the License.
1414

1515
""" Tests for the common MaxText utilities """
16+
import os
1617
import unittest
1718
import numpy as np
1819

20+
from MaxText import pyconfig
21+
from MaxText.globals import MAXTEXT_REPO_ROOT
1922
from maxtext.multimodal import processor as mm_processor
2023
from maxtext.multimodal import utils as mm_utils
2124
from maxtext.multimodal import processor_gemma3
@@ -195,8 +198,12 @@ def test_post_process_image_tokens(self):
195198
pixel_values=dummy_pixel_values,
196199
aspect_ratios=dummy_aspect_ratios,
197200
)
198-
199-
image_offsets = mm_processor.get_image_offsets(model_name=self.model_name, processor_output=processor_output)
201+
base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml")
202+
config = pyconfig.initialize(
203+
["", base_config_path],
204+
model_name="llama4-17b-16e",
205+
)
206+
image_offsets = mm_processor.get_image_offsets(config=config, processor_output=processor_output)
200207
post_processed_tokens = processor_llama4.add_extra_tokens_for_images_llama4(dummy_tokens, processor_output)
201208
self.assertEqual(post_processed_tokens.shape[0], dummy_tokens.shape[0] + image_offsets)
202209

0 commit comments

Comments
 (0)