Skip to content

Commit 9fc7f95

Browse files
committed
Rename Qwen3NextRotaryEmbedding to PartialRotaryEmbedding to make it more generic and reusable.
1 parent 96b72fb commit 9fc7f95

4 files changed

Lines changed: 178 additions & 10 deletions

File tree

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
Qwen3OmniMoeVisionRotaryEmbedding,
6363
RotaryEmbedding,
6464
YarnRotaryEmbedding,
65-
Qwen3NextRotaryEmbedding,
65+
PartialRotaryEmbedding,
6666
)
6767
from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init
6868
from maxtext.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes
@@ -814,7 +814,7 @@ def init_rotary_embedding(self):
814814
rngs=self.rngs,
815815
)
816816
elif self.is_qwen3_next:
817-
rotary_embedding = Qwen3NextRotaryEmbedding(
817+
rotary_embedding = PartialRotaryEmbedding(
818818
min_timescale=self.config.rope_min_timescale,
819819
max_timescale=self.config.rope_max_timescale,
820820
mesh=self.mesh,

src/maxtext/layers/embeddings.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def llama_rotary_embedding_as_linen(
406406
)
407407

408408

409-
def qwen3_next_rotary_embedding_as_linen(
409+
def partial_rotary_embedding_as_linen(
410410
*,
411411
min_timescale: int,
412412
max_timescale: int,
@@ -418,7 +418,7 @@ def qwen3_next_rotary_embedding_as_linen(
418418
shard_mode: ShardMode = ShardMode.AUTO,
419419
name: str | None = None,
420420
):
421-
"""Initializes the Qwen3NextRotaryEmbedding module and returns it as a Linen module.
421+
"""Initializes the PartialRotaryEmbedding module and returns it as a Linen module.
422422
423423
Args:
424424
min_timescale: Start of the geometric index. Determines the periodicity of
@@ -432,7 +432,7 @@ def qwen3_next_rotary_embedding_as_linen(
432432
name: Name of the Linen module.
433433
"""
434434
return nnx_wrappers.to_linen(
435-
Qwen3NextRotaryEmbedding,
435+
PartialRotaryEmbedding,
436436
min_timescale=min_timescale,
437437
max_timescale=max_timescale,
438438
mesh=mesh,
@@ -446,8 +446,8 @@ def qwen3_next_rotary_embedding_as_linen(
446446
)
447447

448448

449-
class Qwen3NextRotaryEmbedding(RotaryEmbedding):
450-
"""Qwen3 Next variant of ROPE (partial ROPE)"""
449+
class PartialRotaryEmbedding(RotaryEmbedding):
450+
"""Rotary Position Embedding applied to a partial fraction of dimensions."""
451451

452452
def __init__(
453453
self,
@@ -461,7 +461,7 @@ def __init__(
461461
shard_mode: ShardMode = ShardMode.AUTO,
462462
rngs: nnx.Rngs = None,
463463
):
464-
"""Initializes the Qwen3NextRotaryEmbedding module.
464+
"""Initializes the PartialRotaryEmbedding module.
465465
466466
Args:
467467
min_timescale: Start of the geometric index. Determines the periodicity of
@@ -476,6 +476,7 @@ def __init__(
476476
self.partial_rotary_factor = partial_rotary_factor
477477
self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
478478

479+
# Initialize the base class with only the rotary_dim
479480
super().__init__(
480481
min_timescale=min_timescale,
481482
max_timescale=max_timescale,
@@ -488,7 +489,7 @@ def __init__(
488489
)
489490

490491
def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
491-
"""Applies LLaMA variant of rotary position embedding.
492+
"""Applies Partial variant of rotary position embedding.
492493
493494
Args:
494495
inputs: The input sequence on which to apply the Rotary position
@@ -499,6 +500,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
499500
Returns:
500501
A jax.Array of shape [B, S, H, D]: the first rotary_dim features have rotary position embeddings applied, the remaining D - rotary_dim features pass through unchanged.
501502
"""
503+
# Split, apply base RoPE to the first fraction, and concatenate
502504
inputs_rot, inputs_pass = jnp.split(inputs, [self.rotary_dim], axis=-1)
503505
inputs_rot = super().__call__(inputs_rot, position)
504506
inputs = jnp.concatenate([inputs_rot, inputs_pass], axis=-1)

src/maxtext/models/qwen3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ class Qwen3NextFullAttention(nnx.Module):
615615
- Query and Gate splitting from a single q projection.
616616
- Application of a sigmoid gate to the attention output.
617617
- Usage of `Qwen3NextRMSNorm` for query and key normalization.
618-
- Usage of `Qwen3NextRotaryEmbedding` for partial rotary position embeddings.
618+
- Usage of `PartialRotaryEmbedding` for partial rotary position embeddings.
619619
- Partial ROPE is applied to the first 25% of head dimensions
620620
621621
Attributes:
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for the partial rotary position embedding layer.
16+
17+
The PartialRotaryEmbedding class is a thin subclass of
18+
`RotaryEmbedding` that applies RoPE only to the first fraction of the
19+
head dimensions. The tests below exercise the half- and fully-rotated
20+
cases and verify basic shift invariance in the same style used by our
21+
existing rotary unit tests.
22+
"""
23+
24+
import sys
25+
import unittest
26+
27+
import jax
28+
import jax.numpy as jnp
29+
from jax.sharding import Mesh
30+
from flax import nnx
31+
import numpy as np
32+
33+
from maxtext.layers.embeddings import PartialRotaryEmbedding, RotaryEmbedding
34+
from maxtext.configs import pyconfig
35+
from maxtext.utils import maxtext_utils
36+
from tests.utils.test_helpers import get_test_config_path
37+
38+
39+
class PartialRotaryEmbeddingTest(unittest.TestCase):
40+
"""Tests for the PartialRotaryEmbedding layer."""
41+
42+
def setUp(self):
43+
super().setUp()
44+
# build a simple config and mesh like other embedding tests
45+
self.cfg = pyconfig.initialize(
46+
[sys.argv[0], get_test_config_path()],
47+
run_name="test_embeddings",
48+
enable_checkpointing=False,
49+
)
50+
devices_array = maxtext_utils.create_device_mesh(self.cfg)
51+
self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
52+
self.nnx_rng = nnx.Rngs(params=0)
53+
54+
def test_partial_rotary_half(self):
55+
"""The first half of the hidden dim should be rotated, the rest passthrough."""
56+
batch_size, seq_len, num_heads, head_dim = 2, 16, 4, 8
57+
inputs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim))
58+
positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, seq_len)
59+
60+
rope_half = PartialRotaryEmbedding(
61+
min_timescale=1,
62+
max_timescale=10000,
63+
mesh=self.mesh,
64+
embedding_dims=head_dim,
65+
partial_rotary_factor=0.5,
66+
rngs=self.nnx_rng,
67+
cast_as_fprop_dtype=False,
68+
)
69+
70+
y_half = rope_half(inputs, positions)
71+
72+
rotary_dim = head_dim // 2
73+
inputs_rot, inputs_pass = inputs[..., :rotary_dim], inputs[..., rotary_dim:]
74+
75+
rope_full_for_rot_part = RotaryEmbedding(
76+
min_timescale=1,
77+
max_timescale=10000,
78+
mesh=self.mesh,
79+
embedding_dims=rotary_dim,
80+
rngs=self.nnx_rng,
81+
cast_as_fprop_dtype=False,
82+
)
83+
y_rot_expected = rope_full_for_rot_part(inputs_rot, positions)
84+
85+
np.testing.assert_allclose(
86+
y_half[..., :rotary_dim],
87+
y_rot_expected,
88+
rtol=1e-6,
89+
atol=1e-6,
90+
err_msg="First fraction should be rotated.",
91+
)
92+
np.testing.assert_allclose(
93+
y_half[..., rotary_dim:],
94+
inputs_pass,
95+
rtol=1e-6,
96+
atol=1e-6,
97+
err_msg="Remaining dims should pass through.",
98+
)
99+
100+
def test_partial_rotary_full(self):
101+
"""A partial factor of 1.0 is equivalent to the base rotary embedding."""
102+
batch_size, seq_len, num_heads, head_dim = 1, 4, 4, 8
103+
inputs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim))
104+
positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, seq_len)
105+
106+
rope_partial = PartialRotaryEmbedding(
107+
min_timescale=1,
108+
max_timescale=10000,
109+
mesh=self.mesh,
110+
embedding_dims=head_dim,
111+
partial_rotary_factor=1.0,
112+
rngs=self.nnx_rng,
113+
)
114+
y_partial = rope_partial(inputs, positions)
115+
116+
rope_full = RotaryEmbedding(
117+
min_timescale=1,
118+
max_timescale=10000,
119+
mesh=self.mesh,
120+
embedding_dims=head_dim,
121+
rngs=self.nnx_rng,
122+
)
123+
y_full = rope_full(inputs, positions)
124+
125+
np.testing.assert_allclose(
126+
y_partial,
127+
y_full,
128+
rtol=1e-6,
129+
atol=1e-6,
130+
err_msg="PartialRotaryEmbedding with factor=1 should equal full rotary.",
131+
)
132+
133+
def test_shift_invariance(self):
134+
"""Verify that rotary attention computed from partial embedding is shift invariant."""
135+
batch_size, seq_len, num_heads, head_dim = 1, 20, 4, 8
136+
inputs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim))
137+
positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, seq_len)
138+
139+
rope = PartialRotaryEmbedding(
140+
min_timescale=1,
141+
max_timescale=10000,
142+
mesh=self.mesh,
143+
embedding_dims=head_dim,
144+
partial_rotary_factor=0.5,
145+
rngs=self.nnx_rng,
146+
cast_as_fprop_dtype=False,
147+
)
148+
149+
def get_attn(pos):
150+
y = rope(inputs, pos)
151+
return np.einsum("BSNH,BTNH->BSTN", y, y)
152+
153+
ref_attn = get_attn(positions)
154+
shifted_attn = get_attn(positions + 3)
155+
156+
np.testing.assert_allclose(
157+
ref_attn,
158+
shifted_attn,
159+
rtol=1e-6,
160+
atol=1e-6,
161+
err_msg="PartialRotaryEmbedding attention should be shift-invariant.",
162+
)
163+
164+
165+
if __name__ == "__main__":
166+
unittest.main()

0 commit comments

Comments
 (0)