Skip to content

Commit 3bc6095

Browse files
committed
audio latents replicate
1 parent 6a024fa commit 3bc6095

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12811281
if guidance_scale > 1.0:
12821282
latents_jax = latents_jax[batch_size:]
12831283
audio_latents_jax = audio_latents_jax[batch_size:]
1284+
# Replicate audio latents to all devices to avoid sharding issues on decoding
1285+
audio_latents_jax = jax.device_put(audio_latents_jax, NamedSharding(self.mesh, P()))
12841286

12851287
# Unpack and Denormalize Video
12861288
latents = self._unpack_latents(

0 commit comments

Comments
 (0)