Skip to content

Commit 8dfc613

Browse files
committed
Attention Test
1 parent 529aafc commit 8dfc613

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxdiffusion/tests/test_attention_ltx2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,8 @@ def add_stat(name, pt_t, jax_t):

       stats.append({
           "Layer": name,
+          "PT Max": f"{pt_val.max():.4f}",
+          "JAX Max": f"{jax_val.max():.4f}",
           "PT Mean": f"{pt_val.mean():.4f}",
           "JAX Mean": f"{jax_val.mean():.4f}",
           "PT Min": f"{pt_val.min():.4f}",
@@ -379,7 +381,7 @@ def test_attention_mask_parity(self):

       # Switch JAX model to use flash attention for this test
       jax_model.attention_op.attention_kernel = "flash"
-      jax_model.attention_op.mesh = Mesh(jax.devices(), ('x',))
+      jax_model.attention_op.mesh = Mesh(jax.devices(), ('context',))

       np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)

0 commit comments

Comments (0)