Skip to content

Commit edc0cd9

Browse files
committed
rope changes
1 parent efc2681 commit edc0cd9

1 file changed

Lines changed: 14 additions & 13 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,30 @@
2929

3030
def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
3131
"""
32-
Applies Interleaved RoPE to input x.
33-
Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1].
32+
Applies Interleaved RoPE to input x by computing rotation directly on slices.
33+
Avoids creating the full `x_rotated` intermediate tensor.
3434
3535
Args:
3636
x: Input tensor [..., D]
3737
freqs: Tuple of (cos, sin), broadcasting to [..., D]
3838
"""
3939
cos, sin = freqs
4040

41-
# 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2]
42-
# This corresponds to "rearrange(..., (d r) -> ... d r, r=2)"
43-
x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
41+
# Slice even and odd indices (No copy, just metadata views)
42+
x_even = x[..., 0::2]
43+
x_odd = x[..., 1::2]
4444

45-
# 2. Split into components
46-
# x_real = x[..., 0], x_imag = x[..., 1]
47-
x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1]
45+
cos_even = cos[..., 0::2]
46+
sin_even = sin[..., 0::2]
47+
cos_odd = cos[..., 1::2]
48+
sin_odd = sin[..., 1::2]
4849

49-
# 3. Rotate [-x2, x1]
50-
# Corresponds to "stack((-t2, t1))"
51-
x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)
50+
# Direct math (Fuses perfectly on TPU)
51+
out_even = x_even.astype(jnp.float32) * cos_even - x_odd.astype(jnp.float32) * sin_even
52+
out_odd = x_odd.astype(jnp.float32) * cos_odd + x_even.astype(jnp.float32) * sin_odd
5253

53-
# 4. Apply frequencies (Float32 for stability)
54-
out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin
54+
# Interleave the results back
55+
out = jnp.stack([out_even, out_odd], axis=-1).reshape(*x.shape)
5556

5657
return out.astype(x.dtype)
5758

0 commit comments

Comments
 (0)