Skip to content

Commit 63501ca

Browse files
committed
Restore original Python loop path with sharding constraints
1 parent a80b371 commit 63501ca

1 file changed

Lines changed: 16 additions & 9 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1373,15 +1373,23 @@ def __call__(
13731373
)
13741374
else:
13751375
# Old Python loop path
1376-
latents_jax = latents_jax.astype(jnp.float32)
1377-
audio_latents_jax = audio_latents_jax.astype(jnp.float32)
1376+
for i in range(len(timesteps_jax)):
1377+
t = timesteps_jax[i]
1378+
1379+
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
1380+
latents_jax_sharded = latents_jax
1381+
audio_latents_jax_sharded = audio_latents_jax
1382+
1383+
if not self.transformer.scan_layers:
1384+
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1385+
latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
1386+
audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
13781387

1379-
for t in timesteps_jax:
13801388
noise_pred, noise_pred_audio = transformer_forward_pass(
13811389
graphdef,
13821390
state,
1383-
latents_jax,
1384-
audio_latents_jax,
1391+
latents_jax_sharded,
1392+
audio_latents_jax_sharded,
13851393
t,
13861394
video_embeds_sharded,
13871395
audio_embeds_sharded,
@@ -1399,7 +1407,7 @@ def __call__(
13991407
if guidance_scale > 1.0:
14001408
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
14011409
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1402-
1410+
# Audio guidance
14031411
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
14041412
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
14051413

@@ -1409,13 +1417,12 @@ def __call__(
14091417
latents_step = latents_jax
14101418
audio_latents_step = audio_latents_jax
14111419

1420+
# Step
14121421
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1413-
latents_step = latents_step.astype(jnp.float32)
1414-
1422+
14151423
audio_latents_step, _ = self.scheduler.step(
14161424
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
14171425
)
1418-
audio_latents_step = audio_latents_step.astype(jnp.float32)
14191426

14201427
if guidance_scale > 1.0:
14211428
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)

0 commit comments

Comments
 (0)