Skip to content

Commit b0312a0

Browse files
committed
Fix bug
1 parent 8dfc613 commit b0312a0

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/tests/test_attention_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def test_attention_mask_parity(self):
381381

382382
# Switch JAX model to use flash attention for this test
383383
jax_model.attention_op.attention_kernel = "flash"
384-
jax_model.attention_op.mesh = Mesh(jax.devices(), ('context',))
384+
jax_model.attention_op.mesh = Mesh(np.array(jax.devices()).reshape(1,-1), ('data', 'context'))
385385

386386
np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
387387

0 commit comments

Comments (0)