Skip to content

Commit a922596

Browse files
committed
test
1 parent edb8a6e commit a922596

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/maxdiffusion/tests/test_attention_ltx2.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,11 @@ def test_attention_mask_parity(self):
412412
pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)
413413

414414
# JAX
415-
jax_out = jax_model(
416-
jnp.array(np_x),
417-
attention_mask=jax_mask_multiplicative
418-
)
415+
with jax_model.attention_op.mesh:
416+
jax_out = jax_model(
417+
jnp.array(np_x),
418+
attention_mask=jax_mask_multiplicative
419+
)
419420

420421
np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
421422
print("[PASS] Attention Mask Parity Verified.")

0 commit comments

Comments
 (0)