|
29 | 29 |
|
30 | 30 | def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array: |
31 | 31 | """ |
32 | | - Applies Interleaved RoPE to input x. |
33 | | - Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1]. |
| 32 | + Applies Interleaved RoPE to input x by computing rotation directly on slices. |
| 33 | + Avoids creating the full `x_rotated` intermediate tensor. |
34 | 34 |
|
35 | 35 | Args: |
36 | 36 | x: Input tensor [..., D] |
37 | 37 | freqs: Tuple of (cos, sin), broadcasting to [..., D] |
38 | 38 | """ |
39 | 39 | cos, sin = freqs |
40 | 40 |
|
41 | | - # 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2] |
42 | | - # This corresponds to "rearrange(..., (d r) -> ... d r, r=2)" |
43 | | - x_reshaped = x.reshape(*x.shape[:-1], -1, 2) |
| 41 | + # Slice even and odd indices (No copy, just metadata views) |
| 42 | + x_even = x[..., 0::2] |
| 43 | + x_odd = x[..., 1::2] |
44 | 44 |
|
45 | | - # 2. Split into components |
46 | | - # x_real = x[..., 0], x_imag = x[..., 1] |
47 | | - x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1] |
| 45 | + cos_even = cos[..., 0::2] |
| 46 | + sin_even = sin[..., 0::2] |
| 47 | + cos_odd = cos[..., 1::2] |
| 48 | + sin_odd = sin[..., 1::2] |
48 | 49 |
|
49 | | - # 3. Rotate [-x2, x1] |
50 | | - # Corresponds to "stack((-t2, t1))" |
51 | | - x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape) |
| 50 | + # Direct math (Fuses perfectly on TPU) |
| 51 | + out_even = x_even.astype(jnp.float32) * cos_even - x_odd.astype(jnp.float32) * sin_even |
| 52 | + out_odd = x_odd.astype(jnp.float32) * cos_odd + x_even.astype(jnp.float32) * sin_odd |
52 | 53 |
|
53 | | - # 4. Apply frequencies (Float32 for stability) |
54 | | - out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin |
| 54 | + # Interleave the results back |
| 55 | + out = jnp.stack([out_even, out_odd], axis=-1).reshape(*x.shape) |
55 | 56 |
|
56 | 57 | return out.astype(x.dtype) |
57 | 58 |
|
|
0 commit comments