Skip to content

Commit 88984da

Browse files
committed
forcing vae replication
1 parent fd7fa20 commit 88984da

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,15 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13301330
if output_type == "latent":
13311331
return LTX2PipelineOutput(frames=latents, audio=audio_latents)
13321332

1333+
# EXPERIMENT: Force latents to be fully replicated using with_sharding_constraint
1334+
try:
1335+
mesh = latents.sharding.mesh
1336+
replicated_sharding = NamedSharding(mesh, P())
1337+
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
1338+
max_logging.log("[Tuning] Applied replication constraint using with_sharding_constraint.")
1339+
except Exception as e:
1340+
max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}")
1341+
13331342
s_vae = time.perf_counter()
13341343
if getattr(self.vae.config, "timestep_conditioning", False):
13351344
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)

0 commit comments

Comments
 (0)