We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2e9f896 commit a34c7f7 — Copy full SHA for a34c7f7
1 file changed
src/maxdiffusion/tests/ltx2/test_attention_ltx2.py
@@ -378,7 +378,6 @@ def test_attention_mask_parity(self):
378
379
jax_model.attention_op.attention_kernel = "flash"
380
jax_model.attention_op.mesh = mesh
381
- jax_model.attention_op.flash_min_seq_length = 0
382
383
mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
384
pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
0 commit comments