Skip to content

Commit 70b78b3

Browse files
committed
Replace jax.lax.scan with nnx.scan to fix TraceContextError
1 parent cb63764 commit 70b78b3

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1658,6 +1658,10 @@ def scan_body(carry, t):
16581658
initial_carry = (latents_jax, audio_latents_jax, scheduler_state)
16591659

16601660
# Run scan
1661-
final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps_jax)
1661+
final_carry, _ = nnx.scan(
1662+
scan_body,
1663+
in_axes=(nnx.Carry, 0),
1664+
out_axes=(nnx.Carry, 0),
1665+
)(initial_carry, timesteps_jax)
16621666

16631667
return final_carry[0], final_carry[1]

0 commit comments

Comments
 (0)