Skip to content

Commit 2869cfd

Browse files
committed
Move the sharding constraint outside the scan loop
1 parent 5c98903 commit 2869cfd

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1247,13 +1247,6 @@ def step_fn(carry, t):
1247 1247
video_embeds_sharded = video_embeds
1248 1248
audio_embeds_sharded = audio_embeds
1249 1249

1250-
if not self.transformer.scan_layers:
1251-
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1252-
latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
1253-
audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
1254-
video_embeds_sharded = jax.lax.with_sharding_constraint(video_embeds, activation_axis_names)
1255-
audio_embeds_sharded = jax.lax.with_sharding_constraint(audio_embeds, activation_axis_names)
1256-
1257 1250
noise_pred, noise_pred_audio = transformer_forward_pass(
1258 1251
graphdef,
1259 1252
state,
@@ -1301,6 +1294,13 @@ def step_fn(carry, t):
1301 1294

1302 1295
return (new_latents_jax.astype(latents_jax.dtype), new_audio_latents_jax.astype(audio_latents_jax.dtype)), None
1303 1296

1297+
if not self.transformer.scan_layers:
1298+
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1299+
latents_jax = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
1300+
audio_latents_jax = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
1301+
video_embeds = jax.lax.with_sharding_constraint(video_embeds, activation_axis_names)
1302+
audio_embeds = jax.lax.with_sharding_constraint(audio_embeds, activation_axis_names)
1303+
1304 1304
initial_carry = (latents_jax, audio_latents_jax)
1305 1305
with jax.named_scope("denoising_loop"):
1306 1306
(latents_jax, audio_latents_jax), _ = jax.lax.scan(step_fn, initial_carry, timesteps_jax)

0 commit comments

Comments
 (0)