Skip to content

Commit fe30953

Browse files
committed
audio latents replicate
1 parent 3bc6095 commit fe30953

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12771277
latents_jax = latents_step
12781278
audio_latents_jax = audio_latents_step
12791279

1280+
# Replicate audio latents to avoid sharding accumulation issues
1281+
audio_latents_jax = jax.device_put(audio_latents_jax, NamedSharding(self.mesh, P()))
1282+
12801283
# 8. Decode Latents
12811284
if guidance_scale > 1.0:
12821285
latents_jax = latents_jax[batch_size:]

0 commit comments

Comments
 (0)