5959
6060# Qwen3OmniMoe-specific processing
6161QWEN3_OMNI_VISION_START_TOKEN = 151652 # <|vision_start|>
62- QWEN3_OMNI_VISION_END_TOKEN = 151653 # <|vision_eos |>
62+ QWEN3_OMNI_VISION_END_TOKEN = 151653 # <|vision_end |>
6363QWEN3_OMNI_IMAGE_TOKEN = 151655 # <|image_pad|>
6464QWEN3_OMNI_VIDEO_TOKEN = 151656 # <|video_pad|>
6565QWEN3_OMNI_AUDIO_START_TOKEN = 151669 # <|audio_start|>
66- QWEN3_OMNI_AUDIO_END_TOKEN = 151648 # <|audio_eos |>
66+ QWEN3_OMNI_AUDIO_END_TOKEN = 151670 # <|audio_end |>
6767QWEN3_OMNI_AUDIO_TOKEN = 151675 # <|audio_pad|>
6868QWEN3_TEMPORAL_PATCH_SIZE = 2
6969QWEN3_OMNI_IMAGE_SIZE = 768
@@ -90,6 +90,7 @@ class Qwen3OmniPreprocessorOutput(mm_utils.PreprocessorOutput):
9090 num_audios : int = 0
9191 audio_values : None | np .ndarray = None
9292 audio_mask : None | np .ndarray = None
93+ audio_lengths : None | np .ndarray = None
9394
9495
9596def smart_resize (
@@ -477,41 +478,36 @@ def preprocess_mm_data_qwen3_omni(config):
477478 """Placeholder for multimodal data preprocessing."""
478479 processor_outputs = Qwen3OmniPreprocessorOutput ()
479480
480- if config .image_path is not None :
481+ if config .image_path :
481482 images = [mm_utils .load_image_from_path (p ) for p in config .image_path .split ("," )]
482483 pixel_values , pixel_grid_thw = pre_process_qwen3_image (images , config )
483484 processor_outputs .pixel_values = pixel_values
484485 processor_outputs .pixel_grid_thw = pixel_grid_thw
485486 processor_outputs .num_images = len (images )
486487
487- if config .video_path is not None :
488+ if config .video_path :
488489 video_array , _ = _read_video_decord (config .video_path )
489490 video_processed , video_grid_thw = preprocess_video (video_array , config )
490491 processor_outputs .video_values = video_processed
491492 processor_outputs .video_grid_thw = video_grid_thw
492493 processor_outputs .video_second_per_grid = np .asarray ([config .temporal_patch_size_for_vit ], dtype = np .float32 )
493494 processor_outputs .num_videos = 1 # Only one video for now.
494495
495- if config .video_path is not None and config .use_audio_in_video :
496+ if config .video_path and config .use_audio_in_video :
496497 # TODO(hengtaoguo): add support for separate audio files. Now only extract audio from video files.
497498 mt_audio = mm_utils .load_audio (config .video_path , sample_rate = SAMPLE_RATE )
498499 mt_audio , mt_audio_mask = pre_process_audio_qwen3_omni (mt_audio )
499500 processor_outputs .audio_values = mt_audio
500501 processor_outputs .audio_mask = mt_audio_mask
502+ # Compute audio_lengths from audio_mask
503+ audio_mask_sum = np .sum (mt_audio_mask , axis = - 1 )
504+ audio_lengths = _get_feat_extract_output_lengths (audio_mask_sum )
505+ processor_outputs .audio_lengths = np .array (audio_lengths , dtype = np .int32 )
501506
502507 return processor_outputs
503508
504509
505- def add_extra_tokens_for_qwen3_omni (
506- tokens : np .ndarray | list ,
507- image_grid_thw : np .ndarray | None = None ,
508- video_grid_thw : np .ndarray | None = None ,
509- audio_lengths : np .ndarray | None = None ,
510- spatial_merge_size : int = 2 ,
511- use_audio_in_video : bool = False ,
512- second_per_grids : np .ndarray | None = None ,
513- position_id_per_seconds : int = 25 ,
514- ):
510+ def add_extra_tokens_for_qwen3_omni (tokens , config , processor_output ):
515511 """Add extra tokens for Qwen3-Omni multimodal sequences.
516512
517513 Expands special tokens (<|image_pad|>, <|video_pad|>, <|audio_pad|>) into
@@ -532,6 +528,13 @@ def add_extra_tokens_for_qwen3_omni(
532528 Returns:
533529 Expanded token sequence with correct number of image/video/audio tokens.
534530 """
531+ image_grid_thw = getattr (processor_output , "pixel_grid_thw" , None )
532+ video_grid_thw = getattr (processor_output , "video_grid_thw" , None )
533+ audio_lengths = getattr (processor_output , "audio_lengths" , None )
534+ second_per_grids = getattr (processor_output , "video_second_per_grid" , None )
535+ spatial_merge_size = config .spatial_merge_size_for_vit
536+ position_id_per_seconds = config .position_id_per_seconds
537+
535538 if not isinstance (tokens , np .ndarray ):
536539 tokens = np .asarray (tokens )
537540
@@ -561,7 +564,7 @@ def add_extra_tokens_for_qwen3_omni(
561564
562565 # Handle audio-in-video: <|vision_start|><|video_pad|><|vision_end|>
563566 elif (
564- use_audio_in_video
567+ config . use_audio_in_video
565568 and token == QWEN3_OMNI_VISION_START_TOKEN
566569 and i + 2 < len (token_list )
567570 and token_list [i + 1 ] == QWEN3_OMNI_VIDEO_TOKEN
@@ -1039,3 +1042,63 @@ def get_rope_index(
10391042 mrope_position_deltas = np .array (mrope_position_deltas ).reshape (batch_size , 1 )
10401043
10411044 return position_ids , mrope_position_deltas
1045+
1046+
1047+ def reformat_prompt_qwen3_omni (
1048+ prompt , image_placeholder = "<|image|>" , num_images = 0 , video_placeholder = "<|video|>" , num_videos = 0
1049+ ):
1050+ """Reformat the prompt for Qwen3-Omni model."""
1051+ # Qwen3-Omni vision format: <|vision_start|><|image_pad|><|vision_end|>
1052+ # Qwen3-Omni mm token order: image_pad, video_pad, audio_pad (standalone audios), then text tokens.
1053+ # use_audio_in_video mode: such audio tokens are interleaved within video tokens.
1054+ qwen3_image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
1055+ qwen3_video_placeholder = "<|vision_start|><|video_pad|><|vision_end|>"
1056+
1057+ if video_placeholder in prompt :
1058+ prompt = prompt .replace (video_placeholder , qwen3_video_placeholder )
1059+ video_placeholder_count = prompt .count (qwen3_video_placeholder )
1060+ if video_placeholder_count < num_videos :
1061+ prompt = qwen3_video_placeholder * (num_videos - video_placeholder_count ) + prompt
1062+
1063+ if image_placeholder in prompt :
1064+ prompt = prompt .replace (image_placeholder , qwen3_image_placeholder )
1065+ image_placeholder_count = prompt .count (qwen3_image_placeholder )
1066+ if image_placeholder_count < num_images :
1067+ prompt = qwen3_image_placeholder * (num_images - image_placeholder_count ) + prompt
1068+
1069+ # Qwen chat template
1070+ formatted_prompt = f"<|im_start|>user\n { prompt } <|im_end|>\n <|im_start|>assistant\n "
1071+ return formatted_prompt
1072+
1073+
1074+ def get_mm_offsets_qwen3_omni (config , processor_output ):
1075+ """Calculate the token offsets for multimodal tokens in Qwen3-Omni model."""
1076+ # Calculate token expansion for Qwen3-Omni multimodal inputs
1077+ if processor_output is None :
1078+ return 0
1079+
1080+ total_offset = 0
1081+ spatial_merge_size = config .spatial_merge_size_for_vit # Default 2 for Qwen3-Omni
1082+ merge_length = spatial_merge_size ** 2
1083+
1084+ # Image tokens: <|image_pad|> expands to multiple image tokens
1085+ if processor_output .pixel_grid_thw is not None :
1086+ image_grid_thw = processor_output .pixel_grid_thw
1087+ for grid in image_grid_thw :
1088+ num_image_tokens = int ((grid [0 ] * grid [1 ] * grid [2 ]) // merge_length )
1089+ total_offset += num_image_tokens - 1 # -1 for the original <|image_pad|> token
1090+
1091+ # Video tokens: <|video_pad|> expands to multiple video tokens
1092+ if processor_output .video_grid_thw is not None :
1093+ video_grid_thw = processor_output .video_grid_thw
1094+ for grid in video_grid_thw :
1095+ num_video_tokens = int ((grid [0 ] * grid [1 ] * grid [2 ]) // merge_length )
1096+ total_offset += num_video_tokens - 1 # -1 for the original <|video_pad|> token
1097+
1098+ # Audio tokens: <|audio_pad|> expands based on audio_lengths
1099+ if processor_output .audio_lengths is not None :
1100+ audio_lengths = processor_output .audio_lengths
1101+ for audio_len in audio_lengths :
1102+ total_offset += int (audio_len ) - 1 # -1 for the original <|audio_pad|> token
1103+
1104+ return total_offset
0 commit comments