Skip to content

Commit eff54f7

Browse files
committed
removed failed attempts
1 parent 38f7162 commit eff54f7

1 file changed

Lines changed: 3 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,8 +1255,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12551255
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
12561256
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
12571257
# Audio guidance
1258-
# Replicate noise_pred_audio to avoid cross-device communication during CFG
1259-
noise_pred_audio = jax.device_put(noise_pred_audio, NamedSharding(self.mesh, P()))
1258+
12601259
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
12611260
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
12621261

@@ -1279,15 +1278,13 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12791278
latents_jax = latents_step
12801279
audio_latents_jax = audio_latents_step
12811280

1282-
# Replicate audio latents to avoid sharding accumulation issues
1283-
audio_latents_jax = jax.device_put(audio_latents_jax, NamedSharding(self.mesh, P()))
1281+
12841282

12851283
# 8. Decode Latents
12861284
if guidance_scale > 1.0:
12871285
latents_jax = latents_jax[batch_size:]
12881286
audio_latents_jax = audio_latents_jax[batch_size:]
1289-
# Replicate audio latents to all devices to avoid sharding issues on decoding
1290-
audio_latents_jax = jax.device_put(audio_latents_jax, NamedSharding(self.mesh, P()))
1287+
12911288

12921289
# Unpack and Denormalize Video
12931290
latents = self._unpack_latents(

0 commit comments

Comments
 (0)