Skip to content

Commit a86a81f

Browse files
committed
restored
1 parent edc0cd9 commit a86a81f

1 file changed

Lines changed: 13 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,29 @@
2929

3030
def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
3131
"""
32-
Applies Interleaved RoPE to input x by computing rotation directly on slices.
33-
Avoids creating the full `x_rotated` intermediate tensor.
32+
Applies Interleaved RoPE to input x.
33+
Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1].
3434
3535
Args:
3636
x: Input tensor [..., D]
3737
freqs: Tuple of (cos, sin), broadcasting to [..., D]
3838
"""
3939
cos, sin = freqs
4040

41-
# Slice even and odd indices (No copy, just metadata views)
42-
x_even = x[..., 0::2]
43-
x_odd = x[..., 1::2]
41+
# 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2]
42+
# This corresponds to "rearrange(..., (d r) -> ... d r, r=2)"
43+
x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
4444

45-
cos_even = cos[..., 0::2]
46-
sin_even = sin[..., 0::2]
47-
cos_odd = cos[..., 1::2]
48-
sin_odd = sin[..., 1::2]
45+
# 2. Split into components
46+
# x_real = x[..., 0], x_imag = x[..., 1]
47+
x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1]
4948

50-
# Direct math (Fuses perfectly on TPU)
51-
out_even = x_even.astype(jnp.float32) * cos_even - x_odd.astype(jnp.float32) * sin_even
52-
out_odd = x_odd.astype(jnp.float32) * cos_odd + x_even.astype(jnp.float32) * sin_odd
49+
# 3. Rotate [-x2, x1]
50+
# Corresponds to "stack((-t2, t1))"
51+
x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)
5352

54-
# Interleave the results back
55-
out = jnp.stack([out_even, out_odd], axis=-1).reshape(*x.shape)
53+
# 4. Apply frequencies (Float32 for stability)
54+
out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin
5655

5756
return out.astype(x.dtype)
5857

0 commit comments

Comments
 (0)