Commit 529aafc

Attention Tests
1 parent 8e9ff4d commit 529aafc

1 file changed

Lines changed: 144 additions & 67 deletions

File tree

src/maxdiffusion/tests/test_attention_ltx2.py

@@ -21,8 +21,9 @@
 import jax.numpy as jnp
 from flax import nnx
 import pandas as pd
+from jax.sharding import Mesh
 
-# Set JAX to use float32 for precision checks
+# Set JAX to use float32 for higher precision checks
 jax.config.update("jax_default_matmul_precision", "float32")
 
 # ==========================================
@@ -44,11 +45,11 @@ def forward(self, ids):
     num_axes = ids.shape[-1]
     dim_per_axis = self.dim // num_axes
 
-    freqs_list = []
     # Standard RoPE frequencies: theta^(-2i/d)
     freq_indices = torch.arange(0, dim_per_axis, 2, dtype=torch.float32)
     inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))
 
+    freqs_list = []
     for i in range(num_axes):
       axis_pos = ids[..., i]  # [B, S]
       # Outer product: [B, S, 1] * [1, 1, D/2] -> [B, S, D/2]
@@ -65,7 +66,7 @@ def forward(self, ids):
     cos = torch.repeat_interleave(cos, 2, dim=-1)
     sin = torch.repeat_interleave(sin, 2, dim=-1)
 
-    # Add head dim: [B, S, 1, D]
+    # Add head dim for broadcasting: [B, S, 1, D]
     return cos.unsqueeze(2), sin.unsqueeze(2)
 
 
@@ -77,7 +78,7 @@ def apply_rotary_emb_pt(x, cos, sin):
   x1, x2 = x_reshaped.unbind(-1)
   x_rotated = torch.stack((-x2, x1), dim=-1).view(b, h, s, d)
 
-  # Cast to float32 for rotation parity
+  # Cast to float32 for rotation parity with JAX
   orig_dtype = x.dtype
   x_f32 = x.to(torch.float32)
   rot_f32 = x_rotated.to(torch.float32)
@@ -108,7 +109,7 @@ def __init__(self, query_dim, context_dim, heads, dim_head):
         torch.nn.Identity()
     )
 
-  def forward(self, x, context=None, q_rope=None, k_rope=None):
+  def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
     q = self.to_q(x)
     ctx = x if context is None else context
     k = self.to_k(ctx)
@@ -133,15 +134,19 @@ def forward(self, x, context=None, q_rope=None, k_rope=None):
       k_cos, k_sin = k_rope
       k_h = apply_rotary_emb_pt(k_h, k_cos, k_sin)
 
-    out = torch.nn.functional.scaled_dot_product_attention(q_h, k_h, v_h, dropout_p=0.0)
+    # PyTorch Attention expects mask in [B, H, S, S] or additive
+    out = torch.nn.functional.scaled_dot_product_attention(
+        q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0
+    )
+
     out = out.transpose(1, 2).reshape(b, s_q, -1)
 
-    return self.to_out(out)
+    return self.to_out(out), (q, k, v, q_norm, k_norm, out)
 
 # ==========================================
 # 2. JAX Imports
 # ==========================================
-from ..models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
+from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
 
 class LTX2AttentionTest(unittest.TestCase):
 
@@ -192,48 +197,54 @@ def copy_norm(jax_layer, pt_layer):
     return pt_model, jax_model
 
   # ------------------------------------------
-  # 1. RoPE Frequency Parity Test
+  # 1. Output Shape Tests
   # ------------------------------------------
-  def test_rope_frequency_parity(self):
-    """
-    Verifies that LTX2RotaryPosEmbed (JAX) generates the EXACT same
-    frequencies as the PyTorch reference for a given input ID set.
-    """
-    dim = 60  # Divisible by 3 for 3D test
+  def test_shapes(self):
+    """Verifies JAX model handles Video (3D) and Audio (1D) shapes."""
+    model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")
+
+    # Video: [B, S, D]
+    x_vid = jnp.zeros((1, 128, 64))
+    out_vid = model(x_vid)
+    self.assertEqual(out_vid.shape, (1, 128, 64))
 
+    # Audio Cross-Attn: [B, S_vid, D] -> [B, S_aud, D]
+    x_aud = jnp.zeros((1, 32, 64))
+    out_cross = model(x_vid, encoder_hidden_states=x_aud)
+    self.assertEqual(out_cross.shape, (1, 128, 64))
+    print("\n[PASS] Shape Tests Passed.")
+
+  # ------------------------------------------
+  # 2. RoPE Frequency Parity
+  # ------------------------------------------
+  def test_rope_frequency_parity(self):
+    """Verifies JAX RoPE Frequencies match PyTorch."""
+    dim = 60
     rope_pt = PytorchLTX2RotaryPosEmbed(dim=dim)
     rope_jax = LTX2RotaryPosEmbed(dim=dim)
 
-    # Create random IDs [B, S, 3]
     np_ids = np.random.randint(0, 100, (2, 16, 3)).astype(np.float32)
 
-    # Run PyTorch
     pt_cos, pt_sin = rope_pt(torch.from_numpy(np_ids))
-
-    # Run JAX
     jax_cos, jax_sin = rope_jax(jnp.array(np_ids))
 
-    # Compare
-    pt_cos_np = pt_cos.numpy()
-    jax_cos_np = np.array(jax_cos)
-
     # 1e-5 tolerance for freq generation math
-    np.testing.assert_allclose(pt_cos_np, jax_cos_np, atol=1e-5)
+    np.testing.assert_allclose(pt_cos.numpy(), np.array(jax_cos), atol=1e-5)
     np.testing.assert_allclose(pt_sin.numpy(), np.array(jax_sin), atol=1e-5)
-    print("\n[PASS] RoPE Frequency Generation matches PyTorch.")
+    print("[PASS] RoPE Frequency Parity Verified.")
 
   # ------------------------------------------
-  # 2. Strict Parity Test (Full Model)
+  # 3. Strict Parity Test (Full Model, BF16)
   # ------------------------------------------
   def test_parity_bf16_strict(self):
-    """Checks if JAX(TPU) matches PyTorch(CPU) in BF16."""
+    """Checks if JAX matches PyTorch in BF16."""
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)
 
     pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
     jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)
 
     with torch.no_grad():
-      pt_out = pt_model(pt_in)
+      pt_out, _ = pt_model(pt_in)
 
     jax_out = jax_model(jax_in)
 
@@ -247,40 +258,18 @@ def test_parity_bf16_strict(self):
     print("\n[PASS] BF16 Strict Parity Test passed.")
 
   # ------------------------------------------
-  # 3. Layer-wise Stats (Corrected Shape Logic)
+  # 4. Layer-wise Diagnostics
   # ------------------------------------------
   def test_layer_wise_stats(self):
-    """Prints diagnostic stats for every layer."""
+    """Prints diagnostic stats for every layer (Bfloat16)."""
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)
 
     pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
     jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)
 
-    # 1. Run PyTorch Step-by-Step
+    # 1. Run PyTorch Step-by-Step (Get intermediates)
     with torch.no_grad():
-      # Projections
-      pt_q = pt_model.to_q(pt_in)
-      pt_k = pt_model.to_k(pt_in)
-      pt_v = pt_model.to_v(pt_in)
-
-      # Norms
-      pt_qn = pt_model.q_norm(pt_q)
-      pt_kn = pt_model.k_norm(pt_k)
-
-      # Attention Prep (Reshape -> Transpose)
-      b, s, _ = pt_qn.shape
-      pt_q_h = pt_qn.view(b, s, self.heads, self.dim_head).transpose(1, 2)
-      pt_k_h = pt_kn.view(b, s, self.heads, self.dim_head).transpose(1, 2)
-      pt_v_h = pt_v.view(b, s, self.heads, self.dim_head).transpose(1, 2)
-
-      # Attention Op
-      pt_attn_out = torch.nn.functional.scaled_dot_product_attention(pt_q_h, pt_k_h, pt_v_h)
-
-      # Reshape Back
-      pt_attn_flat = pt_attn_out.transpose(1, 2).reshape(b, s, -1)
-
-      # Output
-      pt_out = pt_model.to_out(pt_attn_flat)
+      pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(pt_in)
 
     # 2. Run JAX Step-by-Step
     jax_q = jax_model.to_q(jax_in)
@@ -290,45 +279,133 @@ def test_layer_wise_stats(self):
     jax_qn = jax_model.norm_q(jax_q)
     jax_kn = jax_model.norm_k(jax_k)
 
-    # Attention Op: Pass [B, S, Inner_Dim] directly
-    # The LTX2Attention.__call__ flattens inputs before calling apply_attention,
-    # so we pass the flattened (Inner_Dim) tensors here.
+    # Pass 3D tensors [B, S, Inner_Dim] directly to attention op
+    # NNXAttentionOp handles the internal logic for the kernel
     jax_attn = jax_model.attention_op.apply_attention(jax_qn, jax_kn, jax_v)
     jax_out = jax_model.to_out(jax_attn)
 
-    # 3. Print Comparison Table
+    # 3. Build & Print Comparison Table
     stats = []
     def add_stat(name, pt_t, jax_t):
-      pt_val = pt_t.float().numpy() if isinstance(pt_t, torch.Tensor) else pt_t
+      # Ensure pt_t is a tensor before calling .float().numpy()
+      if isinstance(pt_t, torch.Tensor):
+        pt_val = pt_t.float().numpy()
+      else:
+        pt_val = pt_t
+
       jax_val = np.array(jax_t, dtype=np.float32)
+
       stats.append({
           "Layer": name,
           "PT Mean": f"{pt_val.mean():.4f}",
           "JAX Mean": f"{jax_val.mean():.4f}",
           "PT Min": f"{pt_val.min():.4f}",
           "JAX Min": f"{jax_val.min():.4f}",
-          "PT Max": f"{pt_val.max():.4f}",
-          "JAX Max": f"{jax_val.max():.4f}",
-          "Diff (Mean L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
+          "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
       })
 
     add_stat("Query Proj", pt_q, jax_q)
     add_stat("Key Proj", pt_k, jax_k)
     add_stat("Value Proj", pt_v, jax_v)
     add_stat("Query Norm", pt_qn, jax_qn)
     add_stat("Key Norm", pt_kn, jax_kn)
-    add_stat("Attn Output", pt_attn_flat, jax_attn)
+    add_stat("Attn Output", pt_attn, jax_attn)
    add_stat("Final Output", pt_out, jax_out)
 
     df = pd.DataFrame(stats)
-    print("\n[DIAGNOSTIC] Layer-wise Stats (CPU vs TPU BFloat16):")
+    print("\n[DIAGNOSTIC] Layer-wise Stats:")
     print(df.to_string(index=False))
+
   # ------------------------------------------
-  # 4. Cross-Attention + RoPE Integration
+  # 5. Cross-Attention + RoPE Integration
   # ------------------------------------------
   def test_cross_attn_rope_integration(self):
-    """Verifies Video->Audio cross-attention with RoPE."""
+    """Verifies Video->Audio cross-attention with RoPE (Float32)."""
     S_Q, S_KV = 16, 20
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)
 
-    np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
+    np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
+    np_ctx = np.random.randn(self.B, S_KV, self.D).astype(np.float32)
+
+    rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)
+
+    ids_q = torch.randint(0, 100, (self.B, S_Q, 1))
+    ids_k = torch.randint(0, 100, (self.B, S_KV, 1))
+
+    q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
+    k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())
+
+    def prep_pt(c, s):
+      c = c.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
+      s = s.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
+      return c, s
+
+    pt_q_rope = prep_pt(q_cos_pt, q_sin_pt)
+    pt_k_rope = prep_pt(k_cos_pt, k_sin_pt)
+
+    with torch.no_grad():
+      pt_out, _ = pt_model(
+          torch.from_numpy(np_x),
+          context=torch.from_numpy(np_ctx),
+          q_rope=pt_q_rope,
+          k_rope=pt_k_rope
+      )
+
+    jax_q_rope = (jnp.array(q_cos_pt.numpy()), jnp.array(q_sin_pt.numpy()))
+    jax_k_rope = (jnp.array(k_cos_pt.numpy()), jnp.array(k_sin_pt.numpy()))
+
+    jax_out = jax_model(
+        jnp.array(np_x),
+        encoder_hidden_states=jnp.array(np_ctx),
+        rotary_emb=jax_q_rope,
+        k_rotary_emb=jax_k_rope
+    )
+
+    diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
+    print(f"\n[Cross-Attn + RoPE] Max Diff: {diff:.6f}")
+    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+    print("[PASS] Cross-Attention with RoPE Parity Verified.")
+
+  # ------------------------------------------
+  # 6. Attention Mask Parity
+  # ------------------------------------------
+  def test_attention_mask_parity(self):
+    """
+    Verifies attention masks (padding) work identically using FLASH kernel.
+    Flash kernel in attention_flax expects a multiplicative mask [B, S],
+    while PyTorch SDPA expects an additive mask broadcastable to [B,H,S,S].
+    """
+    pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)
+
+    # Switch JAX model to use flash attention for this test
+    jax_model.attention_op.attention_kernel = "flash"
+    jax_model.attention_op.mesh = Mesh(jax.devices(), ('x',))
+
+    np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
+
+    # Create mask pattern: 1 = keep, 0 = mask out
+    # Shape: [B, S]
+    mask_pattern_np = np.random.randint(0, 2, (self.B, self.S)).astype(np.float32)
+
+    # PyTorch needs ADDITIVE mask: 0 for keep, -inf for mask out
+    # Broadcastable to [B, H, S_q, S_kv]: [1, 1, 1, 16] is ok for B=1, H=4, S=16
+    pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
+
+    # JAX Flash attention needs MULTIPLICATIVE mask: [1, 16]
+    jax_mask_multiplicative = jnp.array(mask_pattern_np)
+
+    # PyTorch
+    with torch.no_grad():
+      pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)
+
+    # JAX
+    jax_out = jax_model(
+        jnp.array(np_x),
+        attention_mask=jax_mask_multiplicative
+    )
+
+    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+    print("[PASS] Attention Mask Parity Verified.")
+
+if __name__ == "__main__":
+  unittest.main()
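To reproduce these parity checks locally, the snippet below is a minimal runner sketch and is not part of this commit. It assumes the maxdiffusion package (and its torch, flax, and pandas test dependencies) is importable from the repository root, for example after an editable install, and it simply loads the LTX2AttentionTest case added above via the standard unittest loader.

import unittest

# Load and run only the attention parity tests added in this commit.
# The dotted module path assumes the package under src/ is on PYTHONPATH.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "maxdiffusion.tests.test_attention_ltx2.LTX2AttentionTest"
)
unittest.TextTestRunner(verbosity=2).run(suite)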
