Skip to content

Commit 5c98903

Browse files
committed
annotation for denoising loop
1 parent 18abd61 commit 5c98903

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,8 @@ def step_fn(carry, t):
13021302
return (new_latents_jax.astype(latents_jax.dtype), new_audio_latents_jax.astype(audio_latents_jax.dtype)), None
13031303

13041304
initial_carry = (latents_jax, audio_latents_jax)
1305-
(latents_jax, audio_latents_jax), _ = jax.lax.scan(step_fn, initial_carry, timesteps_jax)
1305+
with jax.named_scope("denoising_loop"):
1306+
(latents_jax, audio_latents_jax), _ = jax.lax.scan(step_fn, initial_carry, timesteps_jax)
13061307

13071308

13081309

0 commit comments

Comments
 (0)