Skip to content

Commit 0a48d93

Browse files
committed
unittest fix
1 parent 36534a2 commit 0a48d93

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/tests/ltx2/test_attention_ltx2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def test_attention_mask_parity(self):
378378

379379
jax_model.attention_op.attention_kernel = "flash"
380380
jax_model.attention_op.mesh = mesh
381+
jax_model.attention_op.flash_min_seq_length = 0
381382

382383
mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
383384
pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]

0 commit comments

Comments
 (0)