Skip to content

Commit 2159a00

Browse files
entrpn and eitanporat committed
Add MROPE support.
Co-authored-by: Eitan Porat <eporat@lightricks.com>
1 parent 33209a8 commit 2159a00

13 files changed

Lines changed: 1204 additions & 49 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,10 @@ num_conv_layers_for_audio: 3
10131013
max_timescale_for_audio: 10000.0
10141014
max_sample_len_for_audio: 10000
10151015

1016+
use_mrope: false
1017+
mrope_section: [24, 20, 20]
1018+
position_id_per_seconds: 25
1019+
10161020
# Subslice shape in the form of "x,y,z" when using pathways (single controller).
10171021
# Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
10181022
subslice_shape: ""

src/MaxText/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,7 @@ conv_chunksize_for_audio: 500
7474
num_conv_layers_for_audio: 3
7575
max_timescale_for_audio: 10000.0
7676
max_sample_len_for_audio: 10000
77+
# MRoPE Settings (Multi-dimensional RoPE for multimodal)
78+
use_mrope: true
79+
mrope_section: [24, 20, 20]
80+
position_id_per_seconds: 25

src/MaxText/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,9 @@ class MultimodalGeneral(BaseModel):
13941394
video_path: PathStr = Field("", description="Path to a video for decoding.")
13951395
audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
13961396
use_audio_in_video: bool = Field(False, description="Extract and use audio from video files.")
1397+
use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.")
1398+
mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.")
1399+
position_id_per_seconds: int = Field(25, description="Temporal granularity for MRoPE (tokens per second).")
13971400

13981401

13991402
class VisionTower(BaseModel):

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,86 @@ def __init__(self, ignored_ids, axis=1):
712712

713713
def map(self, element):
714714
return shift_and_refine(element, ignored_ids=self.ignored_ids, axis=self.axis)
715+
716+
717+
@dataclasses.dataclass
class ComputeQwen3OmniPositions(grain.MapTransform):
  """Computes 3D position IDs for Qwen3-Omni multimodal sequences.

  This transform replaces the standard 1D sequential positions with 3D
  positions (temporal, height, width) for multimodal models like Qwen3-Omni.

  For text-only sequences, all 3 dimensions receive the same sequential values.
  For multimodal sequences with vision/audio, vision tokens get true 3D positions
  and text tokens continue sequentially from max(vision_pos) + 1.

  The actual position computation is delegated to multimodal_utils.get_rope_index(),
  which can be tested and modified independently.
  """

  # NOTE(review): the @dataclasses.dataclass decorator is inert here — the class
  # defines an explicit __init__ and has no annotated class-level fields, so the
  # decorator generates nothing. Consider removing it or converting to fields.

  def __init__(
      self,
      data_column: str = "inputs",
      spatial_merge_size: int = 2,
      position_id_per_seconds: int = 25,
      use_audio_in_video: bool = False,
  ):
    """Initialize the Qwen3-Omni position computation transform.

    Args:
      data_column: Name of the data column to compute positions for (default: "inputs").
      spatial_merge_size: Number of patches merged spatially (e.g., 2 for 2x2→1).
      position_id_per_seconds: Temporal granularity (tokens per second, typically 25).
      use_audio_in_video: If True, audio tokens are interleaved with video tokens.
    """
    self.data_column = data_column
    self.spatial_merge_size = spatial_merge_size
    self.position_id_per_seconds = position_id_per_seconds
    self.use_audio_in_video = use_audio_in_video

  def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Compute 3D position IDs for the batch element.

    Args:
      element: Dictionary containing:
        - {data_column}: Token IDs with shape (batch, seq_len)
        - {data_column}_segmentation: Attention mask (1=real, 0=padding)
        - image_grid_thw: Optional (num_images, 3) array
        - video_grid_thw: Optional (num_videos, 3) array
        - audio_lengths: Optional (num_audios,) array
        - second_per_grids: Optional (num_videos,) array

    Returns:
      element with {data_column}_position updated to shape (3, batch, seq_len)
      (always 3D, even for text-only sequences) and {data_column}_mrope_deltas
      added with the per-sequence deltas returned by get_rope_index().
    """

    # Extract inputs and metadata
    input_ids = element[self.data_column]
    attention_mask = element.get(f"{self.data_column}_segmentation")

    # Extract multimodal metadata (if present). Presumably get_rope_index()
    # treats a missing (None) entry as "modality absent" — verify against its
    # implementation in multimodal_utils.
    image_grid_thw = element.get("image_grid_thw")
    video_grid_thw = element.get("video_grid_thw")
    audio_lengths = element.get("audio_lengths")
    second_per_grids = element.get("second_per_grids")

    # Call the standalone get_rope_index function from multimodal_utils
    position_ids, mrope_position_deltas = multimodal_utils.get_rope_index(
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        attention_mask=attention_mask,
        use_audio_in_video=self.use_audio_in_video,
        audio_lengths=audio_lengths,
        second_per_grids=second_per_grids,
        spatial_merge_size=self.spatial_merge_size,
        position_id_per_seconds=self.position_id_per_seconds,
    )

    # Update element with 3D positions.
    # Shape: (3, batch, seq_len). Per the class docstring this is always 3D,
    # even for text-only sequences; an earlier comment claimed "(batch, seq_len)
    # for text-only", which contradicted the docstring — confirm the actual
    # text-only shape against get_rope_index().
    element[f"{self.data_column}_position"] = position_ids
    element[f"{self.data_column}_mrope_deltas"] = mrope_position_deltas

    return element

src/MaxText/layers/attentions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from MaxText.layers.embeddings import (
6060
LLaMARotaryEmbedding,
6161
LlamaVisionRotaryEmbedding,
62+
Qwen3OmniMoeThinkerTextRotaryEmbedding,
6263
Qwen3OmniMoeVisionRotaryEmbedding,
6364
RotaryEmbedding,
6465
YarnRotaryEmbedding,
@@ -160,6 +161,8 @@ def attention_as_linen(
160161
is_nope_layer: bool = False,
161162
is_vision: bool = False,
162163
model_mode: str = MODEL_MODE_TRAIN,
164+
use_mrope: bool = False,
165+
mrope_section: tuple[int, int, int] | None = None,
163166
name: str | None = None,
164167
):
165168
"""A factory function to create an Attention as a Linen module.
@@ -222,6 +225,8 @@ def attention_as_linen(
222225
is_nope_layer=is_nope_layer,
223226
is_vision=is_vision,
224227
model_mode=model_mode,
228+
use_mrope=use_mrope,
229+
mrope_section=mrope_section,
225230
name=name,
226231
metadata_fn=variable_to_logically_partitioned,
227232
abstract_init=False,
@@ -320,6 +325,8 @@ def __init__(
320325
is_vision: bool = False,
321326
model_mode: str = MODEL_MODE_TRAIN,
322327
base_kv_cache: bool = True,
328+
use_mrope: bool = False,
329+
mrope_section: tuple[int, int, int] | None = None,
323330
name: str | None = None,
324331
rngs: Optional[nnx.Rngs] = None,
325332
):
@@ -414,6 +421,8 @@ def __init__(
414421
self.is_nope_layer = is_nope_layer
415422
self.is_vision = is_vision
416423
self.model_mode = model_mode
424+
self.use_mrope = use_mrope
425+
self.mrope_section = mrope_section
417426
self.rngs = rngs
418427

419428
self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT
@@ -743,6 +752,17 @@ def init_rotary_embedding(self):
743752
else:
744753
raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
745754

755+
elif self.use_mrope:
756+
rotary_embedding = Qwen3OmniMoeThinkerTextRotaryEmbedding(
757+
min_timescale=self.config.rope_min_timescale,
758+
max_timescale=self.config.rope_max_timescale,
759+
embedding_dims=rope_embedding_dims,
760+
cast_as_fprop_dtype=True,
761+
fprop_dtype=self.dtype,
762+
mrope_section=self.mrope_section,
763+
rngs=self.rngs,
764+
)
765+
746766
elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
747767
rotary_embedding = LLaMARotaryEmbedding(
748768
min_timescale=self.config.rope_min_timescale,

src/MaxText/layers/decoders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def __call__(
152152
ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
153153
compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
154154
reshape_q=cfg.reshape_q,
155+
use_mrope=cfg.use_mrope,
156+
mrope_section=cfg.mrope_section,
155157
model_mode=model_mode,
156158
)
157159

0 commit comments

Comments
 (0)