
Commit 554c7bd

Add LTX2 attention
1 parent ced76d0 commit 554c7bd

2 files changed: 649 additions & 0 deletions

File shown below: 224 additions & 0 deletions

@@ -0,0 +1,224 @@
"""Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any, Dict, Optional, Tuple, Union
from flax import nnx
import jax
import jax.numpy as jnp
from ... import common_types
from ..attention_flax import NNXAttentionOp

Array = common_types.Array
Mesh = common_types.Mesh
DType = common_types.DType


def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
  """
  Applies Interleaved RoPE to input x.
  Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1].

  Args:
    x: Input tensor [B, S, H, D]
    freqs: Tuple of (cos, sin), broadcasting to [B, S, 1, D] or [B, S, H, D]
  """
  cos, sin = freqs

  # 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2]
  # This corresponds to "rearrange(..., (d r) -> ... d r, r=2)"
  x_reshaped = x.reshape(*x.shape[:-1], -1, 2)

  # 2. Split into components
  # x_real = x[..., 0], x_imag = x[..., 1]
  x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1]

  # 3. Rotate [-x2, x1]
  # Corresponds to "stack((-t2, t1))"
  x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)

  # 4. Apply frequencies (Float32 for stability)
  out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin

  return out.astype(x.dtype)

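As a quick sanity check on the pairing logic (illustrative only, not part of the committed file), the rotation above should match multiplying each consecutive (even, odd) channel pair, viewed as a complex number, by exp(i * angle). The shapes and angles below are arbitrary, and the snippet assumes apply_rotary_emb from above is in scope:

import jax
import jax.numpy as jnp

B, S, H, D = 1, 4, 2, 8
x = jax.random.normal(jax.random.PRNGKey(0), (B, S, H, D), dtype=jnp.float32)

# Per-pair angles with broadcast shape [1, S, 1, D // 2].
angles = (0.1 * jnp.arange(S, dtype=jnp.float32)[None, :, None, None]
          * jnp.arange(D // 2, dtype=jnp.float32)[None, None, None, :])
cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)  # [1, S, 1, D]
sin = jnp.repeat(jnp.sin(angles), 2, axis=-1)

out = apply_rotary_emb(x, (cos, sin))

# Reference: treat (x[2i], x[2i+1]) as x[2i] + i*x[2i+1] and rotate by exp(i * angle_i).
xc = x[..., 0::2] + 1j * x[..., 1::2]
ref = xc * jnp.exp(1j * angles)
ref = jnp.stack([ref.real, ref.imag], axis=-1).reshape(x.shape)

assert jnp.allclose(out, ref, atol=1e-5)
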
class LTX2RotaryPosEmbed(nnx.Module):
  """
  RoPE implementation that accepts pre-computed position IDs.
  Allows flexibility for 3D (Video) vs 1D (Audio/Temporal) usage.
  """

  def __init__(self, dim: int, theta: float = 10000.0):
    self.dim = dim
    self.theta = theta

  def __call__(self, ids: Array) -> Tuple[Array, Array]:
    """
    Generates RoPE frequencies.

    Args:
      ids: [B, S, Num_Axes]
        - For Video 3D: Num_Axes=3 (Time, Height, Width)
        - For Audio 1D: Num_Axes=1 (Time)

    Returns:
      cos, sin: [B, S, 1, Dim] (Ready for broadcasting across heads)
    """
    num_axes = ids.shape[-1]
    dim_per_axis = self.dim // num_axes

    # Standard RoPE frequencies
    freq_indices = jnp.arange(0, dim_per_axis, 2, dtype=jnp.float32)
    inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))

    freqs_list = []
    for i in range(num_axes):
      axis_pos = ids[..., i]
      # Outer product: [B, S] x [D_axis/2] -> [B, S, D_axis/2]
      freqs = jnp.einsum('bs,d->bsd', axis_pos, inv_freq)
      freqs_list.append(freqs)

    # Concatenate axes -> [B, S, D/2]
    emb = jnp.concatenate(freqs_list, axis=-1)

    cos = jnp.cos(emb)
    sin = jnp.sin(emb)

    # Repeat for Interleaved RoPE: [c1, c2] -> [c1, c1, c2, c2]
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)

    # Add head dim for broadcasting: [B, S, 1, Inner_Dim]
    return cos[:, :, None, :], sin[:, :, None, :]

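A short usage sketch for the embedding above (illustrative only, not part of the committed file): position IDs for a tiny video latent grid, stacked as (time, height, width). Because the per-axis dimension is dim // num_axes and the result is repeated by 2, the output only reaches the full [B, S, 1, dim] width when dim is divisible by 2 * num_axes; the numbers below are arbitrary but chosen to satisfy that.

import jax.numpy as jnp

dim = 96                       # divisible by 2 * 3 axes
F, Ht, W = 2, 4, 4             # frames, height, width of the latent grid
t, h, w = jnp.meshgrid(jnp.arange(F), jnp.arange(Ht), jnp.arange(W), indexing="ij")
ids = jnp.stack([t, h, w], axis=-1).reshape(1, F * Ht * W, 3).astype(jnp.float32)

rope = LTX2RotaryPosEmbed(dim=dim)
cos, sin = rope(ids)
print(cos.shape)               # (1, 32, 1, 96)
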
class LTX2Attention(nnx.Module):
  def __init__(
      self,
      query_dim: int,
      heads: int,
      dim_head: int,
      context_dim: Optional[int] = None,
      dropout: float = 0.0,
      bias: bool = True,  # LTX-2 uses bias=True for projections
      out_bias: bool = True,
      rngs: nnx.Rngs = None,
      mesh: Mesh = None,
      eps: float = 1e-6,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
  ):
    self.heads = heads
    self.dim_head = dim_head
    self.inner_dim = dim_head * heads
    self.dropout_rate = dropout

    # 1. Projections
    self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

    # Handle Self vs Cross Attention input dims
    kv_dim = context_dim if context_dim is not None else query_dim
    self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
    self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

    # 2. Normalization (Applied to full inner_dim, NOT per-head)
    self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
    self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)

    # 3. Output
    self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)

    if self.dropout_rate > 0:
      self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)
    else:
      self.dropout_layer = None

    self.attention_op = NNXAttentionOp(
        mesh=mesh,
        attention_kernel=attention_kernel,
        scale=dim_head**-0.5,
        heads=heads,
        dim_head=dim_head,
        dtype=dtype,
    )

  def _reshape_rope(self, rope_emb: Tuple[Array, Array]) -> Tuple[Array, Array]:
    """Reshapes [B, S, 1, InnerDim] -> [B, S, Heads, DimHead] for broadcasting."""
    cos, sin = rope_emb
    # If tests pass already shaped tensors, return as is
    if cos.ndim == 4 and cos.shape[-2] == self.heads and cos.shape[-1] == self.dim_head:
      return cos, sin

    # Reshape: [B, S, 1, H*D] -> [B, S, H, D]
    # We assume the last dimension is InnerDim = Heads * DimHead
    new_shape = cos.shape[:-2] + (self.heads, self.dim_head)
    return cos.reshape(new_shape), sin.reshape(new_shape)

  def __call__(
      self,
      hidden_states: Array,
      encoder_hidden_states: Optional[Array] = None,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
      k_rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:

    # Determine context (Self or Cross)
    context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

    # 1. Project
    query = self.to_q(hidden_states)
    key = self.to_k(context)
    value = self.to_v(context)

    # 2. Norm (Full Inner Dimension)
    query = self.norm_q(query)
    key = self.norm_k(key)

    # 3. Reshape to Heads [B, S, H, D]
    query = query.reshape(*query.shape[:-1], self.heads, self.dim_head)
    key = key.reshape(*key.shape[:-1], self.heads, self.dim_head)
    value = value.reshape(*value.shape[:-1], self.heads, self.dim_head)

    # 4. Apply RoPE
    if rotary_emb is not None:
      # Reshape [1, Inner] -> [H, D]
      q_rope = self._reshape_rope(rotary_emb)
      query = apply_rotary_emb(query, q_rope)

      # Key RoPE Logic
      if k_rotary_emb is not None:
        # Explicit Key RoPE (e.g. Cross-Modal)
        k_rope = self._reshape_rope(k_rotary_emb)
        key = apply_rotary_emb(key, k_rope)
      elif encoder_hidden_states is None:
        # Self-Attention: Re-use q_rope
        key = apply_rotary_emb(key, q_rope)

    # 5. Flatten back for AttentionOp [B, S, H*D]
    # NNXAttentionOp expects flattened input for flash kernel
    query = query.reshape(*query.shape[:-2], self.inner_dim)
    key = key.reshape(*key.shape[:-2], self.inner_dim)
    value = value.reshape(*value.shape[:-2], self.inner_dim)

    # 6. Attention
    attn_output = self.attention_op.apply_attention(
        query=query, key=key, value=value, attention_mask=attention_mask
    )

    # 7. Output Projection
    hidden_states = self.to_out(attn_output)

    if self.dropout_layer is not None:
      hidden_states = self.dropout_layer(hidden_states)

    return hidden_states

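Finally, an end-to-end sketch for the 1D (temporal) case (illustrative only, not part of the committed file): self-attention with RoPE on a single device. The sizes are arbitrary, and attention_kernel="dot_product" is an assumption about what NNXAttentionOp accepts as a non-flash fallback; the commit's default is "flash".

import jax
import jax.numpy as jnp
from flax import nnx

heads, dim_head = 4, 32
inner_dim = heads * dim_head            # 128

rope = LTX2RotaryPosEmbed(dim=inner_dim)
attn = LTX2Attention(
    query_dim=inner_dim,
    heads=heads,
    dim_head=dim_head,
    rngs=nnx.Rngs(0),
    mesh=None,                          # single-device sketch
    attention_kernel="dot_product",     # assumed non-flash fallback
)

B, S = 1, 16
ids = jnp.arange(S, dtype=jnp.float32).reshape(B, S, 1)   # Num_Axes = 1 (Time)
x = jax.random.normal(jax.random.PRNGKey(0), (B, S, inner_dim), dtype=jnp.float32)

out = attn(x, rotary_emb=rope(ids))
print(out.shape)                        # expected (1, 16, 128)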
