Commit e79b523

committed
revert replication+other fixes
1 parent f1a5791 commit e79b523

1 file changed: src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
Lines changed: 50 additions & 45 deletions

@@ -1431,37 +1431,40 @@ def _sched_cfg_get(key: str, default):
     else:
       perturbation_mask = None
 
-    # CFG / STG / modality stack duplicate batch rows as [uncond, cond, ...]. Guidance mixes rows
-    # (e.g. cond vs STG perturb) that must refer to the same sample. If axis 0 is sharded across
-    # data parallel mesh, those rows land on different chips and guidance is wrong — video can
-    # still look plausible while audio (tighter cross-modal coupling) goes silent or garbage.
+    # Stacked CFG/STG duplicates batch axis 0 (uncond/cond/STG/...). Guidance splits and combines
+    # rows that must refer to the same sample. Shard only seq/embed axes — keep batch axis
+    # replicated (None) — not full P() replication (which OOMs), unlike sharding batch on `data`.
+    stacked_guidance_batch = latents_jax.shape[0] > batch_size
+
     if hasattr(self, "mesh") and self.mesh is not None:
-      if do_cfg:
-        rep = NamedSharding(self.mesh, P())
+      data_sharding_3d = NamedSharding(self.mesh, P())
+      data_sharding_2d = NamedSharding(self.mesh, P())
+      if hasattr(self, "config") and getattr(self.config, "data_sharding", None):
+        ds = tuple(self.config.data_sharding)
+        if len(ds) >= 3:
+          if stacked_guidance_batch:
+            data_sharding_3d = NamedSharding(self.mesh, P(None, ds[1], ds[2]))
+            data_sharding_2d = NamedSharding(self.mesh, P(None, ds[1]))
+          else:
+            data_sharding_3d = NamedSharding(self.mesh, P(*ds[:3]))
+            data_sharding_2d = NamedSharding(self.mesh, P(*ds[:2]))
+      if stacked_guidance_batch:
         max_logging.log(
-            "LTX2: replicating stacked-batch activations on all devices (required for CFG/STG; "
-            "data-parallel sharding of batch breaks cross-row guidance)."
+            "LTX2: stacked guidance — batch dim 0 is not partitioned (replicated); "
+            "seq/embed use data_sharding so CFG/STG row pairs stay co-located without full P()."
         )
-        if isinstance(prompt_embeds_jax, list):
-          prompt_embeds_jax = [jax.device_put(x, rep) for x in prompt_embeds_jax]
-        else:
-          prompt_embeds_jax = jax.device_put(prompt_embeds_jax, rep)
-        prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, rep)
-        latents_jax = jax.device_put(latents_jax, rep)
-        audio_latents_jax = jax.device_put(audio_latents_jax, rep)
-        if perturbation_mask is not None:
-          perturbation_mask = jax.device_put(perturbation_mask, rep)
+      if isinstance(prompt_embeds_jax, list):
+        prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
       else:
-        data_sharding_3d = NamedSharding(self.mesh, P())
-        data_sharding_2d = NamedSharding(self.mesh, P())
-        if hasattr(self, "config") and hasattr(self.config, "data_sharding"):
-          data_sharding_3d = NamedSharding(self.mesh, P(*self.config.data_sharding[:3]))
-          data_sharding_2d = NamedSharding(self.mesh, P(*self.config.data_sharding[:2]))
-        if isinstance(prompt_embeds_jax, list):
-          prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
-        else:
-          prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
-        prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
+        prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
+      prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
+      if stacked_guidance_batch:
+        latents_jax = jax.device_put(latents_jax, data_sharding_3d)
+        audio_latents_jax = jax.device_put(audio_latents_jax, data_sharding_3d)
+        if perturbation_mask is not None:
+          perturbation_mask = jax.device_put(
+              perturbation_mask, NamedSharding(self.mesh, P(None, None, None))
+          )
 
     # GraphDef and State
     graphdef, state = nnx.split(self.transformer)
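Editor's note: the new sharding specs above keep batch axis 0 unpartitioned (P(None, ...)) while still sharding the sequence/embedding axes taken from config.data_sharding. Below is a minimal, runnable sketch of that idea, not pipeline code; it assumes a toy 1×N mesh with axes ("data", "fsdp"), toy shapes, and a hypothetical guidance scale.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed toy mesh: one "data" row, all devices on "fsdp".
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), axis_names=("data", "fsdp"))

batch, seq, embed = 2, 8 * jax.device_count(), 4   # batch=2 stacks [uncond, cond]
x = jnp.arange(batch * seq * embed, dtype=jnp.float32).reshape(batch, seq, embed)

# Batch axis left replicated (None); sequence axis sharded over "fsdp".
stacked_spec = NamedSharding(mesh, P(None, "fsdp", None))
x = jax.device_put(x, stacked_spec)

# Every device holds both rows of axis 0, so the usual guidance combine
# uncond + scale * (cond - uncond) never mixes rows that live on different chips.
uncond, cond = jnp.split(x, 2, axis=0)
guided = uncond + 4.0 * (cond - uncond)
print(guided.shape, guided.sharding)
```

By contrast, a spec that maps axis 0 to a data axis larger than one would place the uncond and cond rows on different devices, which is exactly the situation the stacked_guidance_batch branch avoids, and full P() replication avoids it too but at the memory cost the new comment calls out.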
@@ -1491,24 +1494,23 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
     video_embeds_sharded = video_embeds
     audio_embeds_sharded = audio_embeds
 
-    if hasattr(self, "mesh") and self.mesh is not None and do_cfg:
-      rep = NamedSharding(self.mesh, P())
-      video_embeds_sharded = jax.device_put(video_embeds_sharded, rep)
-      audio_embeds_sharded = jax.device_put(audio_embeds_sharded, rep)
-      new_attention_mask = jax.device_put(new_attention_mask, rep)
-
-    if not self.transformer.scan_layers and not do_cfg:
+    if hasattr(self, "mesh") and self.mesh is not None and stacked_guidance_batch:
+      video_embeds_sharded = jax.device_put(video_embeds_sharded, data_sharding_3d)
+      audio_embeds_sharded = jax.device_put(audio_embeds_sharded, data_sharding_3d)
+      new_attention_mask = jax.device_put(new_attention_mask, data_sharding_2d)
+
+    if (
+        not self.transformer.scan_layers
+        and not stacked_guidance_batch
+        and hasattr(self, "mesh")
+        and self.mesh is not None
+    ):
       activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
       spec = NamedSharding(self.mesh, P(*activation_axes))
       video_embeds_sharded = jax.device_put(video_embeds_sharded, spec)
       audio_embeds_sharded = jax.device_put(audio_embeds_sharded, spec)
 
     timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
-    guidance_rep = (
-        NamedSharding(self.mesh, P())
-        if (do_cfg and hasattr(self, "mesh") and self.mesh is not None)
-        else None
-    )
 
     for i in range(len(timesteps_jax)):
       t = timesteps_jax[i]
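Editor's note: the non-stacked path keeps the existing logical-axis translation, where nn.logical_to_mesh_axes turns ("activation_batch", "activation_length", "activation_embed") into a PartitionSpec according to whatever logical-axis rules are in scope. Here is a hedged sketch of that mechanism; the rule set, mesh axis names, and shapes below are illustrative assumptions, not values from the pipeline's config.

```python
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 3-axis mesh; all devices on "data" for simplicity.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1, 1), axis_names=("data", "fsdp", "tensor"))

# Assumed logical-to-mesh rules (the real ones come from the model's partitioning config).
rules = (("activation_batch", "data"),
         ("activation_length", "fsdp"),
         ("activation_embed", "tensor"))

embeds = jnp.zeros((2 * jax.device_count(), 128, 64))  # (batch, length, embed)

with nn.logical_axis_rules(rules):
  activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))

spec = NamedSharding(mesh, P(*activation_axes))  # here resolves to P("data", "fsdp", "tensor")
embeds = jax.device_put(embeds, spec)            # batch axis is split across "data"
```

This is also why the stacked-guidance branch bypasses it: a rule that maps activation_batch onto a data axis would partition the stacked axis 0, which the new code deliberately keeps whole.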
@@ -1517,7 +1519,12 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
       latents_jax_sharded = latents_jax
       audio_latents_jax_sharded = audio_latents_jax
 
-      if not self.transformer.scan_layers and not do_cfg:
+      if (
+          not self.transformer.scan_layers
+          and not stacked_guidance_batch
+          and hasattr(self, "mesh")
+          and self.mesh is not None
+      ):
         activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
         latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
         audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
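Editor's note: inside the denoising loop the same condition now also requires a mesh before applying with_sharding_constraint, since the call receives a bare PartitionSpec from nn.logical_to_mesh_axes rather than a NamedSharding and therefore has to resolve it against an ambient mesh. A small sketch of that pattern, with assumed axis names and shapes:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "fsdp"))

@jax.jit
def constrain(x):
  # A bare PartitionSpec is resolved against the mesh that is active around the call.
  return jax.lax.with_sharding_constraint(x, P("data", "fsdp", None))

latents = jnp.zeros((jax.device_count(), 256, 32))
with mesh:  # without an active mesh, the bare-PartitionSpec constraint cannot be resolved
  latents_sharded = constrain(latents)
```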
@@ -1543,13 +1550,11 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
           use_cross_timestep=use_cross_timestep,
       )
 
-      if guidance_rep is not None:
-        noise_pred = jax.device_put(noise_pred, guidance_rep)
-        noise_pred_audio = jax.device_put(noise_pred_audio, guidance_rep)
-
       do_cfg = guidance_scale > 1.0
       do_stg = stg_scale > 0.0
-      sigma_t = sigmas[i]
+      # Match diffusers: use scheduler sigmas after set_timesteps_ltx2 (dynamic shift), not the
+      # pre-shift `sigmas` passed into retrieve_timesteps.
+      sigma_t = scheduler_state.sigmas[i]
 
       # Extract latents_step based on stacking strategy
      if do_cfg and do_stg:
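Editor's note: the last change swaps the per-step sigma source. After set_timesteps_ltx2 applies its dynamic shift, the schedule the scheduler actually steps on lives in scheduler_state.sigmas, so indexing the pre-shift sigmas passed into retrieve_timesteps reads a different value at the same step index. A toy illustration of the mismatch, assuming the shift has the common flow-matching exponential form (the real shift applied by set_timesteps_ltx2 may differ):

```python
import jax.numpy as jnp

def shift_sigmas(sigmas, mu):
  """Assumed dynamic-shift form: sigma' = exp(mu) * sigma / (exp(mu) * sigma + 1 - sigma)."""
  return jnp.exp(mu) * sigmas / (jnp.exp(mu) * sigmas + 1.0 - sigmas)

pre_shift = jnp.linspace(1.0, 0.125, 8)    # toy schedule handed to retrieve_timesteps
shifted = shift_sigmas(pre_shift, mu=1.0)  # what the scheduler state holds after set_timesteps

i = 3
print(float(pre_shift[i]), float(shifted[i]))  # same index, different sigma_t
```

Indexing the pre-shift array would therefore apply the wrong noise level at each step, which is what the one-line change avoids.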
