|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Unit tests for the partial rotary position embedding layer. |
| 16 | +
|
| 17 | +The new PartialRotaryEmbedding class is a thin wrapper around |
| 18 | +`RotaryEmbedding` that applies RoPE only to the first fraction of the |
| 19 | +hidden dimensions. The tests below exercise the half/fully-rotated |
| 20 | +cases and verify basic shift invariance in the same style used by our |
| 21 | +existing rotary unit tests. |
| 22 | +""" |
| 23 | + |
| 24 | +import sys |
| 25 | +import unittest |
| 26 | + |
| 27 | +import jax |
| 28 | +import jax.numpy as jnp |
| 29 | +from jax.sharding import Mesh |
| 30 | +from flax import nnx |
| 31 | +import numpy as np |
| 32 | + |
| 33 | +from maxtext.layers.embeddings import PartialRotaryEmbedding, RotaryEmbedding |
| 34 | +from maxtext.configs import pyconfig |
| 35 | +from maxtext.utils import maxtext_utils |
| 36 | +from tests.utils.test_helpers import get_test_config_path |
| 37 | + |
| 38 | + |
| 39 | +class PartialRotaryEmbeddingTest(unittest.TestCase): |
| 40 | + """Tests for the PartialRotaryEmbedding layer.""" |
| 41 | + |
| 42 | + def setUp(self): |
| 43 | + super().setUp() |
| 44 | + # build a simple config and mesh like other embedding tests |
| 45 | + self.cfg = pyconfig.initialize( |
| 46 | + [sys.argv[0], get_test_config_path()], |
| 47 | + run_name="test_embeddings", |
| 48 | + enable_checkpointing=False, |
| 49 | + ) |
| 50 | + devices_array = maxtext_utils.create_device_mesh(self.cfg) |
| 51 | + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) |
| 52 | + self.nnx_rng = nnx.Rngs(params=0) |
| 53 | + |
| 54 | + def test_partial_rotary_half(self): |
| 55 | + """The first half of the hidden dim should be rotated, the rest passthrough.""" |
| 56 | + batch_size, seq_len, num_heads, head_dim = 2, 16, 4, 8 |
| 57 | + inputs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim)) |
| 58 | + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, seq_len) |
| 59 | + |
| 60 | + rope_half = PartialRotaryEmbedding( |
| 61 | + min_timescale=1, |
| 62 | + max_timescale=10000, |
| 63 | + mesh=self.mesh, |
| 64 | + embedding_dims=head_dim, |
| 65 | + partial_rotary_factor=0.5, |
| 66 | + rngs=self.nnx_rng, |
| 67 | + cast_as_fprop_dtype=False, |
| 68 | + ) |
| 69 | + |
| 70 | + y_half = rope_half(inputs, positions) |
| 71 | + |
| 72 | + rotary_dim = head_dim // 2 |
| 73 | + inputs_rot, inputs_pass = inputs[..., :rotary_dim], inputs[..., rotary_dim:] |
| 74 | + |
| 75 | + rope_full_for_rot_part = RotaryEmbedding( |
| 76 | + min_timescale=1, |
| 77 | + max_timescale=10000, |
| 78 | + mesh=self.mesh, |
| 79 | + embedding_dims=rotary_dim, |
| 80 | + rngs=self.nnx_rng, |
| 81 | + cast_as_fprop_dtype=False, |
| 82 | + ) |
| 83 | + y_rot_expected = rope_full_for_rot_part(inputs_rot, positions) |
| 84 | + |
| 85 | + np.testing.assert_allclose( |
| 86 | + y_half[..., :rotary_dim], |
| 87 | + y_rot_expected, |
| 88 | + rtol=1e-6, |
| 89 | + atol=1e-6, |
| 90 | + err_msg="First fraction should be rotated.", |
| 91 | + ) |
| 92 | + np.testing.assert_allclose( |
| 93 | + y_half[..., rotary_dim:], |
| 94 | + inputs_pass, |
| 95 | + rtol=1e-6, |
| 96 | + atol=1e-6, |
| 97 | + err_msg="Remaining dims should pass through.", |
| 98 | + ) |
| 99 | + |
| 100 | + def test_partial_rotary_full(self): |
| 101 | + """A partial factor of 1.0 is equivalent to the base rotary embedding.""" |
| 102 | + batch_size, seq_len, num_heads, head_dim = 1, 4, 4, 8 |
| 103 | + inputs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim)) |
| 104 | + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, seq_len) |
| 105 | + |
| 106 | + rope_partial = PartialRotaryEmbedding( |
| 107 | + min_timescale=1, |
| 108 | + max_timescale=10000, |
| 109 | + mesh=self.mesh, |
| 110 | + embedding_dims=head_dim, |
| 111 | + partial_rotary_factor=1.0, |
| 112 | + rngs=self.nnx_rng, |
| 113 | + ) |
| 114 | + y_partial = rope_partial(inputs, positions) |
| 115 | + |
| 116 | + rope_full = RotaryEmbedding( |
| 117 | + min_timescale=1, |
| 118 | + max_timescale=10000, |
| 119 | + mesh=self.mesh, |
| 120 | + embedding_dims=head_dim, |
| 121 | + rngs=self.nnx_rng, |
| 122 | + ) |
| 123 | + y_full = rope_full(inputs, positions) |
| 124 | + |
| 125 | + np.testing.assert_allclose( |
| 126 | + y_partial, |
| 127 | + y_full, |
| 128 | + rtol=1e-6, |
| 129 | + atol=1e-6, |
| 130 | + err_msg="PartialRotaryEmbedding with factor=1 should equal full rotary.", |
| 131 | + ) |
| 132 | + |
| 133 | + def test_shift_invariance(self): |
| 134 | + """Verify that rotary attention computed from partial embedding is shift invariant.""" |
| 135 | + batch_size, seq_len, num_heads, head_dim = 1, 20, 4, 8 |
| 136 | + inputs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim)) |
| 137 | + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, seq_len) |
| 138 | + |
| 139 | + rope = PartialRotaryEmbedding( |
| 140 | + min_timescale=1, |
| 141 | + max_timescale=10000, |
| 142 | + mesh=self.mesh, |
| 143 | + embedding_dims=head_dim, |
| 144 | + partial_rotary_factor=0.5, |
| 145 | + rngs=self.nnx_rng, |
| 146 | + cast_as_fprop_dtype=False, |
| 147 | + ) |
| 148 | + |
| 149 | + def get_attn(pos): |
| 150 | + y = rope(inputs, pos) |
| 151 | + return np.einsum("BSNH,BTNH->BSTN", y, y) |
| 152 | + |
| 153 | + ref_attn = get_attn(positions) |
| 154 | + shifted_attn = get_attn(positions + 3) |
| 155 | + |
| 156 | + np.testing.assert_allclose( |
| 157 | + ref_attn, |
| 158 | + shifted_attn, |
| 159 | + rtol=1e-6, |
| 160 | + atol=1e-6, |
| 161 | + err_msg="PartialRotaryEmbedding attention should be shift-invariant.", |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +if __name__ == "__main__": |
| 166 | + unittest.main() |
0 commit comments