Skip to content

Commit 1f7888b

Browse files
committed
Attention
1 parent c127236 commit 1f7888b

2 files changed

Lines changed: 74 additions & 34 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,39 @@
2626

2727

2828
def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
  """
  Applies Interleaved RoPE to input x.
  Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1].

  Args:
    x: Input tensor [B, S, H, D]
    freqs: Tuple of (cos, sin), broadcasting to [B, S, 1, D] or [B, S, H, D]
  """
  cos, sin = freqs

  # Pair adjacent channels: [..., D] -> [..., D//2, 2]
  # (equivalent to rearrange "... (d r) -> ... d r", r=2)
  paired = x.reshape(*x.shape[:-1], -1, 2)
  even, odd = paired[..., 0], paired[..., 1]

  # Interleaved rotation [-x2, x1], i.e. stack((-t2, t1)) in the reference impl.
  rotated = jnp.stack([-odd, even], axis=-1).reshape(*x.shape)

  # Combine in float32 for numerical stability, then restore input dtype.
  mixed = x.astype(jnp.float32) * cos + rotated.astype(jnp.float32) * sin
  return mixed.astype(x.dtype)
4255

4356

4457
class LTX2RotaryPosEmbed(nnx.Module):
  """
  RoPE implementation that accepts pre-computed position IDs.
  Allows flexibility for 3D (Video) vs 1D (Audio/Temporal) usage.
  """

  def __init__(self, dim: int, theta: float = 10000.0):
    self.dim = dim
    self.theta = theta

  def __call__(self, ids: Array) -> Tuple[Array, Array]:
    """
    Generates RoPE frequencies.
    Args:
      ids: [B, S, Num_Axes]
        - For Video 3D: Num_Axes=3 (Time, Height, Width)
        - For Audio 1D: Num_Axes=1 (Time)
    Returns:
      cos, sin: [B, S, 1, Dim] (Ready for broadcasting across heads)
    """
    num_axes = ids.shape[-1]
    dim_per_axis = self.dim // num_axes

    # Standard RoPE inverse frequencies: theta^(-2i/d) for even indices i.
    exponents = jnp.arange(0, dim_per_axis, 2, dtype=jnp.float32) / dim_per_axis
    inv_freq = self.theta ** -exponents

    # Outer product per axis ([B, S] x [D_axis/2] -> [B, S, D_axis/2]),
    # then concatenate axes -> [B, S, D/2].
    per_axis = [jnp.einsum('bs,d->bsd', ids[..., i], inv_freq) for i in range(num_axes)]
    angles = jnp.concatenate(per_axis, axis=-1)

    # Duplicate each frequency for Interleaved RoPE: [c1, c2] -> [c1, c1, c2, c2],
    # and insert a head dim for broadcasting: [B, S, 1, Dim].
    cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)
    sin = jnp.repeat(jnp.sin(angles), 2, axis=-1)
    return cos[:, :, None, :], sin[:, :, None, :]
80102

81103

@@ -87,7 +109,7 @@ def __init__(
87109
dim_head: int,
88110
context_dim: Optional[int] = None,
89111
dropout: float = 0.0,
90-
bias: bool = True,
112+
bias: bool = True, # LTX-2 uses bias=True for projections
91113
out_bias: bool = True,
92114
rngs: nnx.Rngs = None,
93115
mesh: Mesh = None,
@@ -100,16 +122,19 @@ def __init__(
100122
self.inner_dim = dim_head * heads
101123
self.dropout_rate = dropout
102124

125+
# 1. Projections
103126
self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
104127

128+
# Handle Self vs Cross Attention input dims
105129
kv_dim = context_dim if context_dim is not None else query_dim
106130
self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
107131
self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
108132

109-
# Norm over full inner_dim (Fix #2)
133+
# 2. Normalization (Applied to full inner_dim, NOT per-head)
110134
self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
111135
self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
112136

137+
# 3. Output
113138
self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)
114139

115140
if self.dropout_rate > 0:
@@ -126,6 +151,18 @@ def __init__(
126151
dtype=dtype,
127152
)
128153

154+
def _reshape_rope(self, rope_emb: Tuple[Array, Array]) -> Tuple[Array, Array]:
155+
"""Reshapes [B, S, 1, InnerDim] -> [B, S, Heads, DimHead] for broadcasting."""
156+
cos, sin = rope_emb
157+
# If tests pass already shaped tensors, return as is
158+
if cos.ndim == 4 and cos.shape[-2] == self.heads and cos.shape[-1] == self.dim_head:
159+
return cos, sin
160+
161+
# Reshape: [B, S, 1, H*D] -> [B, S, H, D]
162+
# We assume the last dimension is InnerDim = Heads * DimHead
163+
new_shape = cos.shape[:-2] + (self.heads, self.dim_head)
164+
return cos.reshape(new_shape), sin.reshape(new_shape)
165+
129166
def __call__(
130167
self,
131168
hidden_states: Array,
@@ -135,33 +172,40 @@ def __call__(
135172
k_rotary_emb: Optional[Tuple[Array, Array]] = None,
136173
) -> Array:
137174

175+
# Determine context (Self or Cross)
138176
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
139177

140178
# 1. Project
141179
query = self.to_q(hidden_states)
142180
key = self.to_k(context)
143181
value = self.to_v(context)
144182

145-
# 2. Norm
183+
# 2. Norm (Full Inner Dimension)
146184
query = self.norm_q(query)
147185
key = self.norm_k(key)
148186

149-
# 3. Reshape for RoPE [B, S, H, D]
187+
# 3. Reshape to Heads [B, S, H, D]
150188
query = query.reshape(*query.shape[:-1], self.heads, self.dim_head)
151189
key = key.reshape(*key.shape[:-1], self.heads, self.dim_head)
152190
value = value.reshape(*value.shape[:-1], self.heads, self.dim_head)
153191

154192
# 4. Apply RoPE
155193
if rotary_emb is not None:
156-
query = apply_rotary_emb(query, rotary_emb)
194+
# Reshape [1, Inner] -> [H, D]
195+
q_rope = self._reshape_rope(rotary_emb)
196+
query = apply_rotary_emb(query, q_rope)
157197

198+
# Key RoPE Logic
158199
if k_rotary_emb is not None:
159-
key = apply_rotary_emb(key, k_rotary_emb)
160-
elif encoder_hidden_states is None: # Self-Attention
161-
key = apply_rotary_emb(key, rotary_emb)
162-
163-
# 5. Flatten back for AttentionOp (Fix #1)
164-
# [B, S, H, D] -> [B, S, H*D]
200+
# Explicit Key RoPE (e.g. Cross-Modal)
201+
k_rope = self._reshape_rope(k_rotary_emb)
202+
key = apply_rotary_emb(key, k_rope)
203+
elif encoder_hidden_states is None:
204+
# Self-Attention: Re-use q_rope
205+
key = apply_rotary_emb(key, q_rope)
206+
207+
# 5. Flatten back for AttentionOp [B, S, H*D]
208+
# NNXAttentionOp expects flattened input for flash kernel
165209
query = query.reshape(*query.shape[:-2], self.inner_dim)
166210
key = key.reshape(*key.shape[:-2], self.inner_dim)
167211
value = value.reshape(*value.shape[:-2], self.inner_dim)
@@ -171,10 +215,10 @@ def __call__(
171215
query=query, key=key, value=value, attention_mask=attention_mask
172216
)
173217

174-
# attn_output is already [B, S, H*D], no reshape needed before output proj
218+
# 7. Output Projection
175219
hidden_states = self.to_out(attn_output)
176220

177221
if self.dropout_layer is not None:
178222
hidden_states = self.dropout_layer(hidden_states)
179223

180-
return hidden_states
224+
return hidden_states

src/maxdiffusion/tests/attention_ltx2_test.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,14 @@
1515

1616
import unittest
1717
from flax import nnx
18-
import jax
1918
import jax.numpy as jnp
20-
# Adjust this import to match your file structure
2119
from ..models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
2220

2321

2422
class LTX2AttentionTest(unittest.TestCase):
2523

2624
def test_rope_video_shapes_3d(self):
2725
"""Test 3D RoPE generation for Video (Time, Height, Width)."""
28-
dim = 64
2926
# LTX-2 splits dim across axes. 60 is divisible by 3 (20 per axis).
3027
dim = 60
3128
rope = LTX2RotaryPosEmbed(dim=dim, theta=10000.0)
@@ -40,7 +37,6 @@ def test_rope_video_shapes_3d(self):
4037
cos, sin = rope(ids)
4138

4239
# Expected output: [B, S, 1, D] (The 1 is for broadcasting across heads)
43-
# This confirms the RoPE module outputs the correct broadcasting shape.
4440
self.assertEqual(cos.shape, (batch_size, seq_len, 1, dim))
4541
self.assertEqual(sin.shape, (batch_size, seq_len, 1, dim))
4642

@@ -51,34 +47,30 @@ def test_rope_audio_shapes_1d(self):
5147

5248
batch_size = 2
5349
seq_len = 20
54-
55-
# Create dummy position IDs for [Time]
5650
# Shape: [B, S, 1]
5751
ids = jnp.zeros((batch_size, seq_len, 1), dtype=jnp.float32)
5852

5953
cos, sin = rope(ids)
60-
61-
# Expected output: [B, S, 1, D]
6254
self.assertEqual(cos.shape, (batch_size, seq_len, 1, dim))
63-
self.assertEqual(sin.shape, (batch_size, seq_len, 1, dim))
6455

6556
def test_self_attention_forward(self):
6657
"""Test basic Self-Attention forward pass (Video <-> Video)."""
6758
dim = 64
6859
heads = 4
6960
dim_head = 16 # inner_dim = 64
7061

62+
# Use dot_product to avoid TPU mesh requirements during unit testing
7163
model = LTX2Attention(
7264
query_dim=dim,
7365
heads=heads,
7466
dim_head=dim_head,
7567
rngs=nnx.Rngs(0),
68+
attention_kernel="dot_product"
7669
)
7770

7871
# Standard input [B, S, D]
7972
x = jnp.ones((1, 16, dim))
8073

81-
# Forward
8274
out = model(hidden_states=x)
8375

8476
self.assertEqual(out.shape, (1, 16, dim))
@@ -96,6 +88,7 @@ def test_cross_attention_forward(self):
9688
dim_head=dim_head,
9789
context_dim=context_dim, # Triggers cross-attention init
9890
rngs=nnx.Rngs(0),
91+
attention_kernel="dot_product"
9992
)
10093

10194
x = jnp.ones((1, 16, query_dim)) # Video
@@ -116,12 +109,14 @@ def test_attention_with_rope_integration(self):
116109
heads=heads,
117110
dim_head=dim_head,
118111
rngs=nnx.Rngs(0),
112+
attention_kernel="dot_product"
119113
)
120114

121115
x = jnp.ones((2, 8, dim))
122116

123-
# Create manual RoPE embeddings matching the output of LTX2RotaryPosEmbed
124-
# Shape: [B, S, 1, inner_dim]
117+
# Create manual RoPE embeddings matching output of LTX2RotaryPosEmbed
118+
# Shape: [B, S, 1, Inner_Dim]
119+
# Inner_Dim = 64
125120
cos = jnp.ones((2, 8, 1, 64))
126121
sin = jnp.ones((2, 8, 1, 64))
127122
rope_emb = (cos, sin)
@@ -144,6 +139,7 @@ def test_cross_modal_temporal_rope(self):
144139
dim_head=dim_head,
145140
context_dim=query_dim,
146141
rngs=nnx.Rngs(0),
142+
attention_kernel="dot_product"
147143
)
148144

149145
x = jnp.ones((1, 16, query_dim)) # Video
@@ -153,7 +149,7 @@ def test_cross_modal_temporal_rope(self):
153149
video_ids_3d = jnp.zeros((1, 16, 3))
154150

155151
# 2. Extract ONLY Time axis for Cross-Attention RoPE
156-
# The pipeline must do this slicing: ids[:, :, 0:1]
152+
# This simulates the logic that will live in the pipeline/model loop
157153
video_ids_temporal = video_ids_3d[..., 0:1] # Shape [1, 16, 1]
158154

159155
# 3. Generate 1D IDs for Audio

0 commit comments

Comments
 (0)