Skip to content

Commit f1a5791

Browse files
committed
replicating stacked-batch activations
1 parent afb5173 commit f1a5791

1 file changed

Lines changed: 49 additions & 13 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,17 +1431,37 @@ def _sched_cfg_get(key: str, default):
14311431
else:
14321432
perturbation_mask = None
14331433

1434+
# CFG / STG / modality stack duplicate batch rows as [uncond, cond, ...]. Guidance mixes rows
1435+
# (e.g. cond vs STG perturb) that must refer to the same sample. If axis 0 is sharded across
1436+
# data parallel mesh, those rows land on different chips and guidance is wrong — video can
1437+
# still look plausible while audio (tighter cross-modal coupling) goes silent or garbage.
14341438
if hasattr(self, "mesh") and self.mesh is not None:
1435-
data_sharding_3d = NamedSharding(self.mesh, P())
1436-
data_sharding_2d = NamedSharding(self.mesh, P())
1437-
if hasattr(self, "config") and hasattr(self.config, "data_sharding"):
1438-
data_sharding_3d = NamedSharding(self.mesh, P(*self.config.data_sharding[:3]))
1439-
data_sharding_2d = NamedSharding(self.mesh, P(*self.config.data_sharding[:2]))
1440-
if isinstance(prompt_embeds_jax, list):
1441-
prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
1439+
if do_cfg:
1440+
rep = NamedSharding(self.mesh, P())
1441+
max_logging.log(
1442+
"LTX2: replicating stacked-batch activations on all devices (required for CFG/STG; "
1443+
"data-parallel sharding of batch breaks cross-row guidance)."
1444+
)
1445+
if isinstance(prompt_embeds_jax, list):
1446+
prompt_embeds_jax = [jax.device_put(x, rep) for x in prompt_embeds_jax]
1447+
else:
1448+
prompt_embeds_jax = jax.device_put(prompt_embeds_jax, rep)
1449+
prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, rep)
1450+
latents_jax = jax.device_put(latents_jax, rep)
1451+
audio_latents_jax = jax.device_put(audio_latents_jax, rep)
1452+
if perturbation_mask is not None:
1453+
perturbation_mask = jax.device_put(perturbation_mask, rep)
14421454
else:
1443-
prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
1444-
prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
1455+
data_sharding_3d = NamedSharding(self.mesh, P())
1456+
data_sharding_2d = NamedSharding(self.mesh, P())
1457+
if hasattr(self, "config") and hasattr(self.config, "data_sharding"):
1458+
data_sharding_3d = NamedSharding(self.mesh, P(*self.config.data_sharding[:3]))
1459+
data_sharding_2d = NamedSharding(self.mesh, P(*self.config.data_sharding[:2]))
1460+
if isinstance(prompt_embeds_jax, list):
1461+
prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
1462+
else:
1463+
prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
1464+
prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
14451465

14461466
# GraphDef and State
14471467
graphdef, state = nnx.split(self.transformer)
@@ -1471,21 +1491,33 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14711491
video_embeds_sharded = video_embeds
14721492
audio_embeds_sharded = audio_embeds
14731493

1474-
if not self.transformer.scan_layers:
1494+
if hasattr(self, "mesh") and self.mesh is not None and do_cfg:
1495+
rep = NamedSharding(self.mesh, P())
1496+
video_embeds_sharded = jax.device_put(video_embeds_sharded, rep)
1497+
audio_embeds_sharded = jax.device_put(audio_embeds_sharded, rep)
1498+
new_attention_mask = jax.device_put(new_attention_mask, rep)
1499+
1500+
if not self.transformer.scan_layers and not do_cfg:
14751501
activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
14761502
spec = NamedSharding(self.mesh, P(*activation_axes))
1477-
video_embeds_sharded = jax.device_put(video_embeds, spec)
1478-
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
1503+
video_embeds_sharded = jax.device_put(video_embeds_sharded, spec)
1504+
audio_embeds_sharded = jax.device_put(audio_embeds_sharded, spec)
14791505

14801506
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1507+
guidance_rep = (
1508+
NamedSharding(self.mesh, P())
1509+
if (do_cfg and hasattr(self, "mesh") and self.mesh is not None)
1510+
else None
1511+
)
1512+
14811513
for i in range(len(timesteps_jax)):
14821514
t = timesteps_jax[i]
14831515

14841516
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
14851517
latents_jax_sharded = latents_jax
14861518
audio_latents_jax_sharded = audio_latents_jax
14871519

1488-
if not self.transformer.scan_layers:
1520+
if not self.transformer.scan_layers and not do_cfg:
14891521
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
14901522
latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
14911523
audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
@@ -1511,6 +1543,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15111543
use_cross_timestep=use_cross_timestep,
15121544
)
15131545

1546+
if guidance_rep is not None:
1547+
noise_pred = jax.device_put(noise_pred, guidance_rep)
1548+
noise_pred_audio = jax.device_put(noise_pred_audio, guidance_rep)
1549+
15141550
do_cfg = guidance_scale > 1.0
15151551
do_stg = stg_scale > 0.0
15161552
sigma_t = sigmas[i]

0 commit comments

Comments (0)