Skip to content

Commit 00ae2a0

Browse files
committed
Add LTX2 Attention
1 parent 9997c59 commit 00ae2a0

2 files changed

Lines changed: 553 additions & 0 deletions

File tree

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
from typing import Any, Dict, Optional, Tuple, Union
17+
from flax import nnx
18+
import jax
19+
import jax.numpy as jnp
20+
from ... import common_types
21+
from ..attention_flax import NNXAttentionOp
22+
23+
Array = common_types.Array
24+
Mesh = common_types.Mesh
25+
DType = common_types.DType
26+
27+
28+
def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
    """Apply interleaved RoPE to ``x``.

    Matches the LTX-2 PyTorch reference, which rotates neighboring pairs
    as ``(x1, x2) -> (-x2, x1)``.

    Args:
        x: Input tensor ``[..., D]`` with even ``D``.
        freqs: ``(cos, sin)`` tables broadcasting to ``[..., D]``.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    cos, sin = freqs

    # Even slots hold the "real" halves of each pair, odd slots the "imag"
    # halves; this is equivalent to reshaping to [..., D//2, 2] and splitting.
    even = x[..., 0::2]
    odd = x[..., 1::2]

    # Rotate each pair: (even, odd) -> (-odd, even), re-interleaved back
    # into the original layout.
    rotated = jnp.stack([-odd, even], axis=-1).reshape(x.shape)

    # Accumulate in float32 for numerical stability, then cast back.
    result = x.astype(jnp.float32) * cos + rotated.astype(jnp.float32) * sin
    return result.astype(x.dtype)
55+
56+
57+
class LTX2RotaryPosEmbed(nnx.Module):
    """RoPE frequency generator driven by pre-computed position ids.

    Keeping position ids external makes the module usable both for 3D
    video positions (time, height, width) and 1D temporal positions
    (audio), since the number of axes is read off the input.
    """

    def __init__(self, dim: int, theta: float = 10000.0):
        self.dim = dim
        self.theta = theta

    def __call__(self, ids: Array) -> Tuple[Array, Array]:
        """Build interleaved-RoPE tables from position ids.

        Args:
            ids: Position ids ``[B, S, Num_Axes]``; ``Num_Axes`` is 3 for
                video (time/height/width) or 1 for audio (time).

        Returns:
            ``(cos, sin)`` tables, each of shape ``[B, S, dim]``.
        """
        num_axes = ids.shape[-1]
        dim_per_axis = self.dim // num_axes

        # Standard RoPE inverse frequencies, shared across all axes.
        exponents = jnp.arange(0, dim_per_axis, 2, dtype=jnp.float32)
        inv_freq = 1.0 / (self.theta ** (exponents / dim_per_axis))

        # Per-axis outer products [B, S] x [D_axis/2] -> [B, S, D_axis/2],
        # concatenated along the feature axis -> [B, S, dim/2].
        angles = jnp.concatenate(
            [jnp.einsum('bs,d->bsd', ids[..., axis], inv_freq) for axis in range(num_axes)],
            axis=-1,
        )

        # Duplicate every entry for the interleaved layout:
        # [c1, c2] -> [c1, c1, c2, c2].
        cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)
        sin = jnp.repeat(jnp.sin(angles), 2, axis=-1)
        return cos, sin
101+
102+
103+
class LTX2Attention(nnx.Module):
    """Multi-head attention block for LTX-2 (self- or cross-attention).

    Projects inputs to Q/K/V, applies RMSNorm over the *full* inner
    dimension (not per-head, matching the LTX-2 reference), optionally
    applies interleaved RoPE, runs the configured attention kernel, and
    projects back to ``query_dim``.
    """

    def __init__(
        self,
        query_dim: int,
        heads: int,
        dim_head: int,
        context_dim: Optional[int] = None,
        dropout: float = 0.0,
        bias: bool = True,  # LTX-2 uses bias=True for projections
        out_bias: bool = True,
        rngs: nnx.Rngs = None,
        mesh: Mesh = None,
        eps: float = 1e-6,
        dtype: DType = jnp.float32,
        attention_kernel: str = "flash",
    ):
        """Initialize projections, norms, and the attention op.

        Args:
            query_dim: Feature size of the query input (and the output).
            heads: Number of attention heads.
            dim_head: Feature size per head; inner dim is ``heads * dim_head``.
            context_dim: Feature size of the cross-attention context; when
                None the block is set up for self-attention.
            dropout: Dropout rate applied after the output projection.
            bias: Whether Q/K/V projections use bias (True in LTX-2).
            out_bias: Whether the output projection uses bias.
            rngs: NNX RNG container for parameter init (and dropout).
            mesh: Device mesh forwarded to the attention op.
            eps: Epsilon for the Q/K RMSNorms.
            dtype: Computation dtype for all sublayers.
            attention_kernel: Kernel name forwarded to NNXAttentionOp
                (e.g. "flash").
        """
        self.heads = heads
        self.dim_head = dim_head
        self.inner_dim = dim_head * heads
        self.dropout_rate = dropout

        # 1. Projections.
        self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

        # Handle self- vs cross-attention input dims.
        kv_dim = context_dim if context_dim is not None else query_dim
        self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
        self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

        # 2. Normalization (applied to the full inner_dim, NOT per-head).
        self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
        self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)

        # 3. Output projection and optional dropout.
        self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)
        if self.dropout_rate > 0:
            self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)
        else:
            self.dropout_layer = None

        self.attention_op = NNXAttentionOp(
            mesh=mesh,
            attention_kernel=attention_kernel,
            scale=dim_head**-0.5,
            heads=heads,
            dim_head=dim_head,
            dtype=dtype,
        )

    def __call__(
        self,
        hidden_states: Array,
        encoder_hidden_states: Optional[Array] = None,
        attention_mask: Optional[Array] = None,
        rotary_emb: Optional[Tuple[Array, Array]] = None,
        k_rotary_emb: Optional[Tuple[Array, Array]] = None,
    ) -> Array:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Query input ``[B, S, query_dim]``.
            encoder_hidden_states: Optional cross-attention context; when
                None, self-attention over ``hidden_states``.
            attention_mask: Optional mask forwarded to the attention op.
            rotary_emb: Optional ``(cos, sin)`` applied to queries (and to
                keys during self-attention when ``k_rotary_emb`` is absent).
            k_rotary_emb: Optional ``(cos, sin)`` applied to keys only.

        Returns:
            Output tensor ``[B, S, query_dim]``.
        """
        # Determine context (self- or cross-attention).
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        # 1. Project.
        query = self.to_q(hidden_states)
        key = self.to_k(context)
        value = self.to_v(context)

        # 2. Norm (full inner dimension).
        query = self.norm_q(query)
        key = self.norm_k(key)

        # 3. RoPE on [B, S, InnerDim] tensors; frequencies broadcast to the
        # same shape.
        if rotary_emb is not None:
            query = apply_rotary_emb(query, rotary_emb)
        if k_rotary_emb is not None:
            key = apply_rotary_emb(key, k_rotary_emb)
        elif encoder_hidden_states is None and rotary_emb is not None:
            # Self-attention: reuse the query frequencies for keys.
            # BUGFIX: the extra `rotary_emb is not None` guard prevents a
            # crash (unpacking None) when no RoPE is supplied at all.
            key = apply_rotary_emb(key, rotary_emb)

        # 4. Attention. NNXAttentionOp expects flattened [B, S, InnerDim]
        # input for the flash kernel.
        attn_output = self.attention_op.apply_attention(
            query=query, key=key, value=value, attention_mask=attention_mask
        )

        # 5. Output projection, then optional dropout.
        hidden_states = self.to_out(attn_output)
        if self.dropout_layer is not None:
            hidden_states = self.dropout_layer(hidden_states)

        return hidden_states

0 commit comments

Comments
 (0)