Skip to content

Commit 18abd61

Browse files
committed
JIT whole diffusion loop
1 parent ef35f7d commit 18abd61

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1299,7 +1299,7 @@ def step_fn(carry, t):
12991299
new_latents_jax = latents_step
13001300
new_audio_latents_jax = audio_latents_step
13011301

1302-
return (new_latents_jax, new_audio_latents_jax), None
1302+
return (new_latents_jax.astype(latents_jax.dtype), new_audio_latents_jax.astype(audio_latents_jax.dtype)), None
13031303

13041304
initial_carry = (latents_jax, audio_latents_jax)
13051305
(latents_jax, audio_latents_jax), _ = jax.lax.scan(step_fn, initial_carry, timesteps_jax)

0 commit comments

Comments
 (0)