@@ -1431,37 +1431,40 @@ def _sched_cfg_get(key: str, default):
         else:
             perturbation_mask = None
 
-        # CFG / STG / modality stack duplicate batch rows as [uncond, cond, ...]. Guidance mixes rows
-        # (e.g. cond vs STG perturb) that must refer to the same sample. If axis 0 is sharded across
-        # data parallel mesh, those rows land on different chips and guidance is wrong — video can
-        # still look plausible while audio (tighter cross-modal coupling) goes silent or garbage.
+        # Stacked CFG/STG duplicates batch axis 0 (uncond/cond/STG/...). Guidance splits and combines
+        # rows that must refer to the same sample, so keep batch axis 0 replicated (None) and shard
+        # only the seq/embed axes: sharding the batch on `data` breaks guidance, and full P() OOMs.
+        stacked_guidance_batch = latents_jax.shape[0] > batch_size
+
         if hasattr(self, "mesh") and self.mesh is not None:
-            if do_cfg:
-                rep = NamedSharding(self.mesh, P())
+            data_sharding_3d = NamedSharding(self.mesh, P())
+            data_sharding_2d = NamedSharding(self.mesh, P())
+            if hasattr(self, "config") and getattr(self.config, "data_sharding", None):
+                ds = tuple(self.config.data_sharding)
+                if len(ds) >= 3:
+                    if stacked_guidance_batch:
+                        data_sharding_3d = NamedSharding(self.mesh, P(None, ds[1], ds[2]))
+                        data_sharding_2d = NamedSharding(self.mesh, P(None, ds[1]))
+                    else:
+                        data_sharding_3d = NamedSharding(self.mesh, P(*ds[:3]))
+                        data_sharding_2d = NamedSharding(self.mesh, P(*ds[:2]))
+            if stacked_guidance_batch:
                 max_logging.log(
-                    "LTX2: replicating stacked-batch activations on all devices (required for CFG/STG; "
-                    "data-parallel sharding of batch breaks cross-row guidance)."
+                    "LTX2: stacked guidance: batch dim 0 stays replicated; "
+                    "seq/embed use data_sharding so CFG/STG row pairs stay co-located without full P()."
                 )
-                if isinstance(prompt_embeds_jax, list):
-                    prompt_embeds_jax = [jax.device_put(x, rep) for x in prompt_embeds_jax]
-                else:
-                    prompt_embeds_jax = jax.device_put(prompt_embeds_jax, rep)
-                prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, rep)
-                latents_jax = jax.device_put(latents_jax, rep)
-                audio_latents_jax = jax.device_put(audio_latents_jax, rep)
-                if perturbation_mask is not None:
-                    perturbation_mask = jax.device_put(perturbation_mask, rep)
+            if isinstance(prompt_embeds_jax, list):
+                prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
             else:
-                data_sharding_3d = NamedSharding(self.mesh, P())
-                data_sharding_2d = NamedSharding(self.mesh, P())
-                if hasattr(self, "config") and hasattr(self.config, "data_sharding"):
-                    data_sharding_3d = NamedSharding(self.mesh, P(*self.config.data_sharding[:3]))
-                    data_sharding_2d = NamedSharding(self.mesh, P(*self.config.data_sharding[:2]))
-                if isinstance(prompt_embeds_jax, list):
-                    prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
-                else:
-                    prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
-                prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
+                prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
+            prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
+            if stacked_guidance_batch:
+                latents_jax = jax.device_put(latents_jax, data_sharding_3d)
+                audio_latents_jax = jax.device_put(audio_latents_jax, data_sharding_3d)
+                if perturbation_mask is not None:
+                    perturbation_mask = jax.device_put(
+                        perturbation_mask, NamedSharding(self.mesh, P(None, None, None))
+                    )
 
         # GraphDef and State
         graphdef, state = nnx.split(self.transformer)
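
A minimal sketch (not part of the change) of why keeping batch axis 0 unpartitioned makes stacked guidance safe: under `P(None, "data", None)`, every device holds both rows of a `[uncond, cond]` pair for its sequence shard, so the guidance mix is purely local. The 1-D mesh, axis name, and shapes here are assumptions for illustration.

```python
# Illustrative only: assumes a 1-D mesh named "data" and a [uncond, cond] stack.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))
batch_size, seq, embed = 1, 8 * jax.device_count(), 16

# Stacked guidance batch: rows [uncond, cond] refer to the same sample.
stacked = jnp.ones((2 * batch_size, seq, embed))

# Batch axis replicated (None); only the sequence axis is sharded.
stacked = jax.device_put(stacked, NamedSharding(mesh, P(None, "data", None)))

# Each device sees both rows for its sequence shard, so the mix below
# needs no cross-device gather of uncond vs. cond.
uncond, cond = jnp.split(stacked, 2, axis=0)
guidance_scale = 3.0
noise_pred = uncond + guidance_scale * (cond - uncond)
```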
@@ -1491,24 +1494,23 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
         video_embeds_sharded = video_embeds
         audio_embeds_sharded = audio_embeds
 
-        if hasattr(self, "mesh") and self.mesh is not None and do_cfg:
-            rep = NamedSharding(self.mesh, P())
-            video_embeds_sharded = jax.device_put(video_embeds_sharded, rep)
-            audio_embeds_sharded = jax.device_put(audio_embeds_sharded, rep)
-            new_attention_mask = jax.device_put(new_attention_mask, rep)
-
-        if not self.transformer.scan_layers and not do_cfg:
+        if hasattr(self, "mesh") and self.mesh is not None and stacked_guidance_batch:
+            video_embeds_sharded = jax.device_put(video_embeds_sharded, data_sharding_3d)
+            audio_embeds_sharded = jax.device_put(audio_embeds_sharded, data_sharding_3d)
+            new_attention_mask = jax.device_put(new_attention_mask, data_sharding_2d)
+
+        if (
+            not self.transformer.scan_layers
+            and not stacked_guidance_batch
+            and hasattr(self, "mesh")
+            and self.mesh is not None
+        ):
             activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
             spec = NamedSharding(self.mesh, P(*activation_axes))
             video_embeds_sharded = jax.device_put(video_embeds_sharded, spec)
             audio_embeds_sharded = jax.device_put(audio_embeds_sharded, spec)
 
         timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
-        guidance_rep = (
-            NamedSharding(self.mesh, P())
-            if (do_cfg and hasattr(self, "mesh") and self.mesh is not None)
-            else None
-        )
 
         for i in range(len(timesteps_jax)):
             t = timesteps_jax[i]
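
To sanity-check that arrays placed like the connector outputs above really keep batch dim 0 replicated, a standalone debugging sketch (mesh name and shapes assumed; `visualize_array_sharding` needs the optional `rich` package):

```python
# Debugging sketch: verify batch dim 0 is replicated while seq is sharded.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))
embeds = jax.device_put(
    jnp.ones((2, 4 * jax.device_count(), 8)),
    NamedSharding(mesh, P(None, "data", None)),
)
print(embeds.sharding)  # batch dim unpartitioned, seq dim on "data"
jax.debug.visualize_array_sharding(embeds[:, :, 0])  # accepts 1-D/2-D views
```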
@@ -1517,7 +1519,12 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
             latents_jax_sharded = latents_jax
             audio_latents_jax_sharded = audio_latents_jax
 
-            if not self.transformer.scan_layers and not do_cfg:
+            if (
+                not self.transformer.scan_layers
+                and not stacked_guidance_batch
+                and hasattr(self, "mesh")
+                and self.mesh is not None
+            ):
                 activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
                 latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
                 audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
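
Both guarded blocks resolve logical axis names through `nn.logical_to_mesh_axes`, which consults whatever logical-axis rules are in scope. A self-contained sketch of that resolution, with an assumed rule set (the repo's actual rules come from its config, not from here):

```python
# Sketch with an assumed rule set: maps logical activation axes to mesh axes.
import flax.linen as nn

rules = (
    ("activation_batch", "data"),
    ("activation_length", "fsdp"),
    ("activation_embed", None),
)
with nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))

print(spec)  # PartitionSpec('data', 'fsdp', None); with a mesh active, usable as
             # jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
```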
@@ -1543,13 +1550,11 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
                 use_cross_timestep=use_cross_timestep,
             )
 
-            if guidance_rep is not None:
-                noise_pred = jax.device_put(noise_pred, guidance_rep)
-                noise_pred_audio = jax.device_put(noise_pred_audio, guidance_rep)
-
             do_cfg = guidance_scale > 1.0
             do_stg = stg_scale > 0.0
-            sigma_t = sigmas[i]
+            # Match diffusers: use the scheduler's sigmas after set_timesteps_ltx2 (dynamic shift),
+            # not the pre-shift `sigmas` passed into retrieve_timesteps.
+            sigma_t = scheduler_state.sigmas[i]
 
             # Extract latents_step based on stacking strategy
             if do_cfg and do_stg:
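
On the sigma fix: diffusers-style flow-match schedulers recompute sigmas inside `set_timesteps` by applying a time shift, so indexing the pre-shift array desynchronizes the noise level from the timestep actually stepped. A hedged sketch of the static form of that shift (LTX2's dynamic variant derives the shift value from the token count, which is omitted here):

```python
# Illustrative only: the static flow-match time shift; set_timesteps stores the
# *shifted* sigmas on the scheduler state, which is what the loop must index.
import numpy as np

def shift_sigmas(sigmas: np.ndarray, shift: float) -> np.ndarray:
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)

pre_shift = np.linspace(1.0, 1.0 / 8.0, 8)       # what the caller passed in
post_shift = shift_sigmas(pre_shift, shift=3.0)  # what the scheduler steps with
# pre_shift[i] != post_shift[i], so `sigma_t = sigmas[i]` mixed guidance math
# computed at one noise level with a step taken at another.
```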