Skip to content

Commit edb8a6e

Browse files
committed
Fix logic
1 parent b0312a0 commit edb8a6e

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

src/maxdiffusion/tests/test_attention_ltx2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from flax import nnx
2323
import pandas as pd
2424
from jax.sharding import Mesh
25+
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
2526

2627
# Set JAX to use float32 for higher precision checks
2728
jax.config.update("jax_default_matmul_precision", "float32")
@@ -382,6 +383,16 @@ def test_attention_mask_parity(self):
382383
# Switch JAX model to use flash attention for this test
383384
jax_model.attention_op.attention_kernel = "flash"
384385
jax_model.attention_op.mesh = Mesh(np.array(jax.devices()).reshape(1,-1), ('data', 'context'))
386+
jax_model.attention_op.flash_block_sizes = splash_attention_kernel.BlockSizes(
387+
block_q=512,
388+
block_kv_compute=128,
389+
block_kv=128,
390+
block_q_dkv=512,
391+
block_kv_dkv=128,
392+
block_kv_dkv_compute=128,
393+
block_q_dq=512,
394+
block_kv_dq=128,
395+
)
385396

386397
np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
387398

0 commit comments

Comments
 (0)