We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a34c7f7 · commit a4c1b6e (copy full SHA for a4c1b6e)
1 file changed
src/maxdiffusion/tests/ltx2/test_attention_ltx2.py
@@ -378,6 +378,7 @@ def test_attention_mask_parity(self):
     jax_model.attention_op.attention_kernel = "flash"
     jax_model.attention_op.mesh = mesh
+    jax_model.attention_op.flash_min_seq_length = 0

     mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
     pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
0 commit comments