Skip to content

Commit 4b969d6

Browse files
committed
annotation for denoising loop
1 parent 2869cfd commit 4b969d6

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,7 +1292,7 @@ def step_fn(carry, t):
12921292
new_latents_jax = latents_step
12931293
new_audio_latents_jax = audio_latents_step
12941294

1295-
return (new_latents_jax.astype(latents_jax.dtype), new_audio_latents_jax.astype(audio_latents_jax.dtype)), None
1295+
return (new_latents_jax, new_audio_latents_jax), None
12961296

12971297
if not self.transformer.scan_layers:
12981298
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
@@ -1301,10 +1301,13 @@ def step_fn(carry, t):
13011301
video_embeds = jax.lax.with_sharding_constraint(video_embeds, activation_axis_names)
13021302
audio_embeds = jax.lax.with_sharding_constraint(audio_embeds, activation_axis_names)
13031303

1304-
initial_carry = (latents_jax, audio_latents_jax)
1304+
initial_carry = (latents_jax.astype(jnp.float32), audio_latents_jax.astype(jnp.float32))
13051305
with jax.named_scope("denoising_loop"):
13061306
(latents_jax, audio_latents_jax), _ = jax.lax.scan(step_fn, initial_carry, timesteps_jax)
13071307

1308+
latents_jax = latents_jax.astype(jnp.bfloat16)
1309+
audio_latents_jax = audio_latents_jax.astype(jnp.bfloat16)
1310+
13081311

13091312

13101313
# 8. Decode Latents

0 commit comments

Comments
 (0)