@@ -1431,17 +1431,37 @@ def _sched_cfg_get(key: str, default):
14311431 else :
14321432 perturbation_mask = None
14331433
1434+ # CFG / STG / modality stacking duplicates batch rows as [uncond, cond, ...]. Guidance mixes rows
1435+ # (e.g. cond vs STG perturb) that must refer to the same sample. If axis 0 is sharded across the
1436+ # data-parallel mesh, those rows land on different chips and guidance is wrong — video can
1437+ # still look plausible while audio (tighter cross-modal coupling) goes silent or turns to garbage.
14341438 if hasattr (self , "mesh" ) and self .mesh is not None :
1435- data_sharding_3d = NamedSharding (self .mesh , P ())
1436- data_sharding_2d = NamedSharding (self .mesh , P ())
1437- if hasattr (self , "config" ) and hasattr (self .config , "data_sharding" ):
1438- data_sharding_3d = NamedSharding (self .mesh , P (* self .config .data_sharding [:3 ]))
1439- data_sharding_2d = NamedSharding (self .mesh , P (* self .config .data_sharding [:2 ]))
1440- if isinstance (prompt_embeds_jax , list ):
1441- prompt_embeds_jax = [jax .device_put (x , data_sharding_3d ) for x in prompt_embeds_jax ]
1439+ if do_cfg :
1440+ rep = NamedSharding (self .mesh , P ())
1441+ max_logging .log (
1442+ "LTX2: replicating stacked-batch activations on all devices (required for CFG/STG; "
1443+ "data-parallel sharding of batch breaks cross-row guidance)."
1444+ )
1445+ if isinstance (prompt_embeds_jax , list ):
1446+ prompt_embeds_jax = [jax .device_put (x , rep ) for x in prompt_embeds_jax ]
1447+ else :
1448+ prompt_embeds_jax = jax .device_put (prompt_embeds_jax , rep )
1449+ prompt_attention_mask_jax = jax .device_put (prompt_attention_mask_jax , rep )
1450+ latents_jax = jax .device_put (latents_jax , rep )
1451+ audio_latents_jax = jax .device_put (audio_latents_jax , rep )
1452+ if perturbation_mask is not None :
1453+ perturbation_mask = jax .device_put (perturbation_mask , rep )
14421454 else :
1443- prompt_embeds_jax = jax .device_put (prompt_embeds_jax , data_sharding_3d )
1444- prompt_attention_mask_jax = jax .device_put (prompt_attention_mask_jax , data_sharding_2d )
1455+ data_sharding_3d = NamedSharding (self .mesh , P ())
1456+ data_sharding_2d = NamedSharding (self .mesh , P ())
1457+ if hasattr (self , "config" ) and hasattr (self .config , "data_sharding" ):
1458+ data_sharding_3d = NamedSharding (self .mesh , P (* self .config .data_sharding [:3 ]))
1459+ data_sharding_2d = NamedSharding (self .mesh , P (* self .config .data_sharding [:2 ]))
1460+ if isinstance (prompt_embeds_jax , list ):
1461+ prompt_embeds_jax = [jax .device_put (x , data_sharding_3d ) for x in prompt_embeds_jax ]
1462+ else :
1463+ prompt_embeds_jax = jax .device_put (prompt_embeds_jax , data_sharding_3d )
1464+ prompt_attention_mask_jax = jax .device_put (prompt_attention_mask_jax , data_sharding_2d )
14451465
14461466 # GraphDef and State
14471467 graphdef , state = nnx .split (self .transformer )
@@ -1471,21 +1491,33 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14711491 video_embeds_sharded = video_embeds
14721492 audio_embeds_sharded = audio_embeds
14731493
1474- if not self .transformer .scan_layers :
1494+ if hasattr (self , "mesh" ) and self .mesh is not None and do_cfg :
1495+ rep = NamedSharding (self .mesh , P ())
1496+ video_embeds_sharded = jax .device_put (video_embeds_sharded , rep )
1497+ audio_embeds_sharded = jax .device_put (audio_embeds_sharded , rep )
1498+ new_attention_mask = jax .device_put (new_attention_mask , rep )
1499+
1500+ if not self .transformer .scan_layers and not do_cfg :
14751501 activation_axes = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
14761502 spec = NamedSharding (self .mesh , P (* activation_axes ))
1477- video_embeds_sharded = jax .device_put (video_embeds , spec )
1478- audio_embeds_sharded = jax .device_put (audio_embeds , spec )
1503+ video_embeds_sharded = jax .device_put (video_embeds_sharded , spec )
1504+ audio_embeds_sharded = jax .device_put (audio_embeds_sharded , spec )
14791505
14801506 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1507+ guidance_rep = (
1508+ NamedSharding (self .mesh , P ())
1509+ if (do_cfg and hasattr (self , "mesh" ) and self .mesh is not None )
1510+ else None
1511+ )
1512+
14811513 for i in range (len (timesteps_jax )):
14821514 t = timesteps_jax [i ]
14831515
14841516 # Isolate input sharding to scan_layers=False to avoid affecting the standard path
14851517 latents_jax_sharded = latents_jax
14861518 audio_latents_jax_sharded = audio_latents_jax
14871519
1488- if not self .transformer .scan_layers :
1520+ if not self .transformer .scan_layers and not do_cfg :
14891521 activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
14901522 latents_jax_sharded = jax .lax .with_sharding_constraint (latents_jax , activation_axis_names )
14911523 audio_latents_jax_sharded = jax .lax .with_sharding_constraint (audio_latents_jax , activation_axis_names )
@@ -1511,6 +1543,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15111543 use_cross_timestep = use_cross_timestep ,
15121544 )
15131545
1546+ if guidance_rep is not None :
1547+ noise_pred = jax .device_put (noise_pred , guidance_rep )
1548+ noise_pred_audio = jax .device_put (noise_pred_audio , guidance_rep )
1549+
15141550 do_cfg = guidance_scale > 1.0
15151551 do_stg = stg_scale > 0.0
15161552 sigma_t = sigmas [i ]
0 commit comments