@@ -712,3 +712,87 @@ def __init__(self, ignored_ids, axis=1):
 
   def map(self, element):
     return shift_and_refine(element, ignored_ids=self.ignored_ids, axis=self.axis)
+
+
+@dataclasses.dataclass
+class ComputeQwen3OmniPositions(grain.MapTransform):
+  """Computes 3D position IDs for Qwen3-Omni multimodal sequences.
+
+  This transform replaces the standard 1D sequential positions with 3D
+  positions (temporal, height, width) for multimodal models like Qwen3-Omni.
+
+  For text-only sequences, all 3 dimensions receive the same sequential values.
+  For multimodal sequences with vision/audio, vision tokens get true 3D positions
+  and text tokens continue sequentially from max(vision_pos) + 1.
+
+  The actual position computation is delegated to multimodal_utils.get_rope_index(),
+  which can be tested and modified independently.
+  """
+
+  def __init__(
+      self,
+      data_column: str = "inputs",
+      spatial_merge_size: int = 2,
+      position_id_per_seconds: int = 25,
+      use_audio_in_video: bool = False,
+  ):
+    """Initialize the Qwen3-Omni position computation transform.
+
+    Args:
+      data_column: Name of the data column to compute positions for (default: "inputs").
+      spatial_merge_size: Number of patches merged spatially (e.g., 2 for a 2x2 → 1 merge).
+      position_id_per_seconds: Temporal granularity (tokens per second, typically 25).
+      use_audio_in_video: If True, audio tokens are interleaved with video tokens.
+    """
+    self.data_column = data_column
+    self.spatial_merge_size = spatial_merge_size
+    self.position_id_per_seconds = position_id_per_seconds
+    self.use_audio_in_video = use_audio_in_video
+
+  def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
+    """Compute 3D position IDs for the batch element.
+
+    Args:
+      element: Dictionary containing:
+        - {data_column}: Token IDs with shape (batch, seq_len).
+        - {data_column}_segmentation: Attention mask (1=real, 0=padding).
+        - image_grid_thw: Optional (num_images, 3) array.
+        - video_grid_thw: Optional (num_videos, 3) array.
+        - audio_lengths: Optional (num_audios,) array.
+        - second_per_grids: Optional (num_videos,) array.
+
+    Returns:
+      The element with {data_column}_position set to an array of shape
+      (3, batch, seq_len) (always 3D, even for text-only sequences) and
+      {data_column}_mrope_deltas set to the per-sequence position deltas.
+    """
+
+    # Extract inputs and metadata
+    input_ids = element[self.data_column]
+    attention_mask = element.get(f"{self.data_column}_segmentation")
+
+    # Extract multimodal metadata (if present)
+    image_grid_thw = element.get("image_grid_thw")
+    video_grid_thw = element.get("video_grid_thw")
+    audio_lengths = element.get("audio_lengths")
+    second_per_grids = element.get("second_per_grids")
+
+    # Call the standalone get_rope_index function from multimodal_utils
+    position_ids, mrope_position_deltas = multimodal_utils.get_rope_index(
+        input_ids=input_ids,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        attention_mask=attention_mask,
+        use_audio_in_video=self.use_audio_in_video,
+        audio_lengths=audio_lengths,
+        second_per_grids=second_per_grids,
+        spatial_merge_size=self.spatial_merge_size,
+        position_id_per_seconds=self.position_id_per_seconds,
+    )
+
+    # Update element with 3D positions.
+    # Shape: (3, batch, seq_len), always 3D even for text-only sequences.
+    element[f"{self.data_column}_position"] = position_ids
+    element[f"{self.data_column}_mrope_deltas"] = mrope_position_deltas
+
+    return element
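To make the docstring's position scheme concrete, here is a hand-worked sketch; the Qwen2-VL-style MRoPE layout for the vision block is an assumption, and the temporal scaling by position_id_per_seconds is elided:

# Hypothetical token layout: [T0 T1 | one vision frame, t=1, h=2, w=2 | T2],
# after spatial merging. The h/w ordering of vision tokens is an assumption;
# text resumes at max(vision_pos) + 1, as the class docstring states.
#
#   token:     T0  T1 | V0  V1  V2  V3 | T2
#   temporal:   0   1 |  2   2   2   2 |  4
#   height:     0   1 |  2   2   3   3 |  4
#   width:      0   1 |  2   3   2   3 |  4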
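And a minimal usage sketch (illustrative, not part of this commit) applying the transform to a text-only batch; the toy element and values below are assumptions:

import numpy as np

# Hypothetical text-only batch: no image_grid_thw / video_grid_thw /
# audio_lengths keys are present, so per the docstring all three position
# planes receive the same sequential values.
transform = ComputeQwen3OmniPositions(data_column="inputs")
element = {
    "inputs": np.array([[5, 6, 7, 0]]),               # (batch=1, seq_len=4)
    "inputs_segmentation": np.array([[1, 1, 1, 0]]),  # 1 = real, 0 = padding
}
out = transform.map(element)
# out["inputs_position"] has shape (3, 1, 4): identical sequential values in
# the temporal, height, and width planes for text-only input.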