Commit de14eec

Commit message: test
1 parent 266b00a commit de14eec

1 file changed

Lines changed: 23 additions & 20 deletions

File tree

src/maxdiffusion/tests/test_attention_ltx2.py

@@ -60,7 +60,7 @@ def forward(self, ids):
         cos = torch.repeat_interleave(cos, 2, dim=-1)
         sin = torch.repeat_interleave(sin, 2, dim=-1)
 
-        # CORRECT: Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
+        # Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
         return cos, sin
 
 
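For context, a minimal sketch of the interleaved-pair rotation these [B, S, InnerDim] tables feed into. The real apply_rotary_emb_pt body is outside this diff, so the exact pairing convention below is an assumption:

import torch

def apply_rotary_emb_pt(x, cos, sin):
    # Assumed convention: cos/sin come from the repeat_interleave above,
    # so cos[..., 2i] == cos[..., 2i+1]; each adjacent pair (x0, x1)
    # rotates to (x0*cos - x1*sin, x1*cos + x0*sin).
    x1, x2 = x[..., 0::2], x[..., 1::2]
    x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x)
    return x * cos + x_rot * sin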
@@ -111,35 +111,37 @@ def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
         k = self.to_k(ctx)
         v = self.to_v(ctx)
 
-        q = self.q_norm(q)
-        k = self.k_norm(k)
+        # Keep raw projections for test_layer_wise_stats
+        q_raw, k_raw = q, k
+
+        q_normed = self.q_norm(q)
+        k_normed = self.k_norm(k)
 
-        # CORRECT: Apply RoPE globally BEFORE splitting heads
         if q_rope is not None:
             q_cos, q_sin = q_rope
-            q = apply_rotary_emb_pt(q, q_cos, q_sin)
+            q_normed = apply_rotary_emb_pt(q_normed, q_cos, q_sin)
 
         if k_rope is not None:
             k_cos, k_sin = k_rope
-            k = apply_rotary_emb_pt(k, k_cos, k_sin)
+            k_normed = apply_rotary_emb_pt(k_normed, k_cos, k_sin)
 
         # Split Heads for Attention
-        b, s_q, _ = q.shape
-        _, s_kv, _ = k.shape
-        q_h = q.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
-        k_h = k.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
+        b, s_q, _ = q_normed.shape
+        _, s_kv, _ = k_normed.shape
+        q_h = q_normed.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
+        k_h = k_normed.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
         v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
 
         out = torch.nn.functional.scaled_dot_product_attention(
             q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0
         )
         out = out.transpose(1, 2).reshape(b, s_q, -1)
-        return self.to_out(out), (q, k, v, q, k, out)  # Returning normed q/k as placeholder
+        return self.to_out(out), (q_raw, k_raw, v, q_normed, k_normed, out)
 
 # ==========================================
 # 2. JAX Imports & Test Suite
 # ==========================================
-from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
+from ..models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
 
 class LTX2AttentionTest(unittest.TestCase):
 
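A standalone sanity check (illustrative only; every name below is invented for the sketch) of why rotating on [B, S, InnerDim] before the head split is safe: with an even dim_head, the global adjacent-pairing lines up with the per-head pairing, so rotating then splitting equals splitting then rotating each head with the matching slice of the tables:

import torch

def rotate(x, cos, sin):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x)
    return x * cos + x_rot * sin

b, s, h, d = 2, 8, 4, 16                 # dim_head d must be even
x, cos, sin = (torch.randn(b, s, h * d) for _ in range(3))

split = lambda t: t.view(b, s, h, d).transpose(1, 2)   # [B, H, S, d]
assert torch.allclose(
    split(rotate(x, cos, sin)),                        # RoPE globally, then split
    rotate(split(x), split(cos), split(sin)),          # split, then RoPE per head
)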
@@ -191,11 +193,11 @@ def test_shapes(self):
         model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")
 
         x_vid = jnp.zeros((1, 128, 64))
-        out_vid = model(x_vid, deterministic=True)
+        out_vid = model(x_vid)
         self.assertEqual(out_vid.shape, (1, 128, 64))
 
         x_aud = jnp.zeros((1, 32, 64))
-        out_cross = model(x_vid, encoder_hidden_states=x_aud, deterministic=True)
+        out_cross = model(x_vid, encoder_hidden_states=x_aud)
         self.assertEqual(out_cross.shape, (1, 128, 64))
         print("\n[PASS] Shape Tests Passed.")
 
@@ -221,7 +223,7 @@ def test_parity_bf16_strict(self):
         with torch.no_grad():
             pt_out, _ = pt_model(pt_in)
 
-        jax_out = jax_model(jax_in, deterministic=True)
+        jax_out = jax_model(jax_in)
 
         pt_res = pt_out.float().numpy()
         jax_res = np.array(jax_out, dtype=np.float32)
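The assertion itself is outside this hunk; a plausible form, assuming typical bf16 tolerances (bf16 keeps only ~8 mantissa bits, so bitwise equality is not the bar even in a "strict" test):

import numpy as np
np.testing.assert_allclose(pt_res, jax_res, rtol=2e-2, atol=2e-2)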
@@ -259,8 +261,12 @@ def add_stat(name, pt_t, jax_t):
             jax_val = np.array(jax_t, dtype=np.float32)
             stats.append({
                 "Layer": name,
+                "PT Max": f"{pt_val.max():.4f}",
+                "JAX Max": f"{jax_val.max():.4f}",
                 "PT Mean": f"{pt_val.mean():.4f}",
                 "JAX Mean": f"{jax_val.mean():.4f}",
+                "PT Min": f"{pt_val.min():.4f}",
+                "JAX Min": f"{jax_val.min():.4f}",
                 "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
             })
 
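The hunk only appends the row dicts; one hypothetical way the suite might render them as an aligned table (the actual printing code is outside this diff):

def print_stats(stats):
    cols = ["Layer", "PT Max", "JAX Max", "PT Mean", "JAX Mean",
            "PT Min", "JAX Min", "Diff (L1)"]
    widths = {c: max([len(c)] + [len(r[c]) for r in stats]) for c in cols}
    print("  ".join(c.ljust(widths[c]) for c in cols))
    for r in stats:
        print("  ".join(r[c].ljust(widths[c]) for c in cols))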
@@ -291,7 +297,6 @@ def test_cross_attn_rope_integration(self):
         q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
         k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())
 
-        # No reshape needed! Passed directly as [B, S, InnerDim]
         with torch.no_grad():
             pt_out, _ = pt_model(
                 torch.from_numpy(np_x),
@@ -307,8 +312,7 @@ def test_cross_attn_rope_integration(self):
             jnp.array(np_x),
             encoder_hidden_states=jnp.array(np_ctx),
             rotary_emb=jax_q_rope,
-            k_rotary_emb=jax_k_rope,
-            deterministic=True
+            k_rotary_emb=jax_k_rope
         )
 
         diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
@@ -342,8 +346,7 @@ def test_attention_mask_parity(self):
         with mesh:
             jax_out = jax_model(
                 jnp.array(np_x),
-                attention_mask=jax_mask_multiplicative,
-                deterministic=True
+                attention_mask=jax_mask_multiplicative
             )
 
         diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
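The mask construction sits outside this hunk; a sketch of the presumed setup, with jax_mask_multiplicative as a 0/1 keep-mask and an additive counterpart on the PyTorch side (all names and shapes here are assumptions, not from the diff):

import numpy as np

keep = np.ones((1, 1, 1, 128), dtype=bool)
keep[..., 64:] = False                       # mask out the second half of the keys

# torch SDPA accepts an additive float mask: 0 where attended, -inf where masked.
pt_mask_additive = np.where(keep, 0.0, -np.inf).astype(np.float32)
# The JAX model here is fed the multiplicative 0/1 form instead.
jax_mask_multiplicative = keep.astype(np.float32)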
