1 parent edb8a6e commit a922596
1 file changed
src/maxdiffusion/tests/test_attention_ltx2.py
@@ -412,10 +412,11 @@ def test_attention_mask_parity(self):
     pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)
 
     # JAX
-    jax_out = jax_model(
-        jnp.array(np_x),
-        attention_mask=jax_mask_multiplicative
-    )
+    with jax_model.attention_op.mesh:
+      jax_out = jax_model(
+          jnp.array(np_x),
+          attention_mask=jax_mask_multiplicative
+      )
 
     np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
     print("[PASS] Attention Mask Parity Verified.")