Skip to content

Commit e7aceb3

Browse files
committed
Force float32 for latents in run_diffusion_loop to test precision effect
1 parent edea34d commit e7aceb3

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,8 @@ def run_diffusion_loop(
15771577
scheduler_step,
15781578
logical_axis_rules,
15791579
):
1580+
latents_jax = latents_jax.astype(jnp.float32)
1581+
audio_latents_jax = audio_latents_jax.astype(jnp.float32)
15801582
transformer = nnx.merge(graphdef, state)
15811583

15821584
def scan_body(carry, t, model):

0 commit comments

Comments
 (0)