|
29 | 29 |
|
30 | 30 | def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array: |
31 | 31 | """ |
32 | | - Applies Interleaved RoPE to input x by computing rotation directly on slices. |
33 | | - Avoids creating the full `x_rotated` intermediate tensor. |
| 32 | + Applies Interleaved RoPE to input x. |
| 33 | + Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1]. |
34 | 34 |
|
35 | 35 | Args: |
36 | 36 | x: Input tensor [..., D] |
37 | 37 | freqs: Tuple of (cos, sin), broadcasting to [..., D] |
38 | 38 | """ |
39 | 39 | cos, sin = freqs |
40 | 40 |
|
41 | | - # Slice even and odd indices (No copy, just metadata views) |
42 | | - x_even = x[..., 0::2] |
43 | | - x_odd = x[..., 1::2] |
| 41 | + # 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2] |
| 42 | + # This corresponds to "rearrange(..., (d r) -> ... d r, r=2)" |
| 43 | + x_reshaped = x.reshape(*x.shape[:-1], -1, 2) |
44 | 44 |
|
45 | | - cos_even = cos[..., 0::2] |
46 | | - sin_even = sin[..., 0::2] |
47 | | - cos_odd = cos[..., 1::2] |
48 | | - sin_odd = sin[..., 1::2] |
| 45 | + # 2. Split into components |
| 46 | + # x_real = x[..., 0], x_imag = x[..., 1] |
| 47 | + x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1] |
49 | 48 |
|
50 | | - # Direct math (Fuses perfectly on TPU) |
51 | | - out_even = x_even.astype(jnp.float32) * cos_even - x_odd.astype(jnp.float32) * sin_even |
52 | | - out_odd = x_odd.astype(jnp.float32) * cos_odd + x_even.astype(jnp.float32) * sin_odd |
| 49 | + # 3. Rotate [-x2, x1] |
| 50 | + # Corresponds to "stack((-t2, t1))" |
| 51 | + x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape) |
53 | 52 |
|
54 | | - # Interleave the results back |
55 | | - out = jnp.stack([out_even, out_odd], axis=-1).reshape(*x.shape) |
| 53 | + # 4. Apply frequencies (Float32 for stability) |
| 54 | + out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin |
56 | 55 |
|
57 | 56 | return out.astype(x.dtype) |
58 | 57 |
|
|
0 commit comments