Skip to content

Commit 346a127

Browse files
committed
replicating latents
1 parent 084dd62 commit 346a127

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,11 @@ def convert_to_x0(lat, vel):
15341534
def convert_to_vel(lat, x0):
15351535
return (lat - x0) / sigma_t
15361536

1537+
# Replicate latents across devices to avoid sharding mismatch during eager split
1538+
replicated_sharding = NamedSharding(self.mesh, P())
1539+
noise_pred = jax.device_put(noise_pred, replicated_sharding)
1540+
noise_pred_audio = jax.device_put(noise_pred_audio, replicated_sharding)
1541+
15371542
if do_cfg and do_stg:
15381543
noise_pred_uncond, noise_pred_text, noise_pred_perturb, noise_pred_isolated = jnp.split(noise_pred, 4, axis=0)
15391544

0 commit comments

Comments
 (0)