Skip to content

Commit 22036e6

Browse files
committed
fix scan layers False issue
1 parent d71aa8d commit 22036e6

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,14 +1240,28 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12401240
total_diffusion_time = 0.0
12411241
for i, t in enumerate(timesteps):
12421242
step_start_time = time.perf_counter()
1243+
1244+
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
1245+
latents_jax_sharded = latents_jax
1246+
audio_latents_jax_sharded = audio_latents_jax
1247+
video_embeds_sharded = video_embeds
1248+
audio_embeds_sharded = audio_embeds
1249+
1250+
if not self.transformer.scan_layers:
1251+
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1252+
latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
1253+
audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
1254+
video_embeds_sharded = jax.lax.with_sharding_constraint(video_embeds, activation_axis_names)
1255+
audio_embeds_sharded = jax.lax.with_sharding_constraint(audio_embeds, activation_axis_names)
1256+
12431257
noise_pred, noise_pred_audio = transformer_forward_pass(
12441258
graphdef,
12451259
state,
1246-
latents_jax,
1247-
audio_latents_jax,
1260+
latents_jax_sharded,
1261+
audio_latents_jax_sharded,
12481262
t,
1249-
video_embeds,
1250-
audio_embeds,
1263+
video_embeds_sharded,
1264+
audio_embeds_sharded,
12511265
new_attention_mask,
12521266
new_attention_mask,
12531267
guidance_scale > 1.0,

0 commit comments

Comments
 (0)