Skip to content

Commit 184edb8

Browse files
committed
fix in calibration
1 parent 712df7c commit 184edb8

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/scripts/calibrate_ltx2_fbs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def get_dummy_ltx2_inputs(batch_size, dtype):
2626
# LTX-2 Audio latents default channels = 128
2727
audio_latents = jax.random.normal(rng, (batch_size, 64, 128), dtype=dtype)
2828
audio_prompt_embeds = jax.random.normal(rng, (batch_size, 64, 3072), dtype=dtype)
29+
encoder_attention_mask = jnp.ones((batch_size, 128), dtype=jnp.int32)
2930
audio_encoder_attention_mask = jnp.ones((batch_size, 64), dtype=jnp.int32)
3031

3132
return latents, audio_latents, timestep, prompt_embeds, audio_prompt_embeds, encoder_attention_mask, audio_encoder_attention_mask

0 commit comments

Comments
 (0)