Skip to content

Commit 23f05d3

Browse files
committed
replicated noise_pred_audio
1 parent fe30953 commit 23f05d3

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12551255
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
12561256
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
12571257
# Audio guidance
1258+
# Replicate noise_pred_audio to avoid cross-device communication during CFG
1259+
noise_pred_audio = jax.device_put(noise_pred_audio, NamedSharding(self.mesh, P()))
12581260
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
12591261
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
12601262

0 commit comments

Comments
 (0)