@@ -1255,8 +1255,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12551255 noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
12561256 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
12571257 # Audio guidance
1258- # Replicate noise_pred_audio to avoid cross-device communication during CFG
1259- noise_pred_audio = jax .device_put (noise_pred_audio , NamedSharding (self .mesh , P ()))
1258+
12601259 noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
12611260 noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
12621261
@@ -1279,15 +1278,13 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12791278 latents_jax = latents_step
12801279 audio_latents_jax = audio_latents_step
12811280
1282- # Replicate audio latents to avoid sharding accumulation issues
1283- audio_latents_jax = jax .device_put (audio_latents_jax , NamedSharding (self .mesh , P ()))
1281+
12841282
12851283 # 8. Decode Latents
12861284 if guidance_scale > 1.0 :
12871285 latents_jax = latents_jax [batch_size :]
12881286 audio_latents_jax = audio_latents_jax [batch_size :]
1289- # Replicate audio latents to all devices to avoid sharding issues on decoding
1290- audio_latents_jax = jax .device_put (audio_latents_jax , NamedSharding (self .mesh , P ()))
1287+
12911288
12921289 # Unpack and Denormalize Video
12931290 latents = self ._unpack_latents (
0 commit comments