Skip to content

Commit d3a22e4

Browse files
committed
replicating vae params as well
1 parent 88984da commit d3a22e4

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,12 +1330,21 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13301330
if output_type == "latent":
13311331
return LTX2PipelineOutput(frames=latents, audio=audio_latents)
13321332

1333-
# EXPERIMENT: Force latents to be fully replicated using with_sharding_constraint
1333+
# EXPERIMENT: Force latents and VAE weights to be fully replicated using with_sharding_constraint
13341334
try:
13351335
mesh = latents.sharding.mesh
13361336
replicated_sharding = NamedSharding(mesh, P())
13371337
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
1338-
max_logging.log("[Tuning] Applied replication constraint using with_sharding_constraint.")
1338+
1339+
# Replicate VAE weights
1340+
graphdef, state = nnx.split(self.vae)
1341+
state = jax.tree_util.tree_map(
1342+
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x,
1343+
state
1344+
)
1345+
self.vae = nnx.merge(graphdef, state)
1346+
1347+
max_logging.log("[Tuning] Applied replication constraint using with_sharding_constraint for latents and VAE.")
13391348
except Exception as e:
13401349
max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}")
13411350

0 commit comments

Comments
 (0)