Skip to content

Commit b59c094

Browse files
committed
Cast scheduler output back to input dtype in scan body
1 parent 13bcbb8 commit b59c094

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,9 +1641,12 @@ def scan_body(carry, t, model):
16411641
latents_step, _ = scheduler_step(
16421642
s_state, noise_pred, t, latents_step, return_dict=False
16431643
)
1644+
latents_step = latents_step.astype(latents.dtype)
1645+
16441646
audio_latents_step, _ = scheduler_step(
16451647
s_state, noise_pred_audio, t, audio_latents_step, return_dict=False
16461648
)
1649+
audio_latents_step = audio_latents_step.astype(audio_latents.dtype)
16471650

16481651
if guidance_scale > 1.0:
16491652
latents_next = jnp.concatenate([latents_step] * 2, axis=0)

0 commit comments

Comments
 (0)