Skip to content

Commit cb63764

Browse files
committed
Move nnx.merge inside scan_body to fix TraceContextError
1 parent de45665 commit cb63764

1 file changed

Lines changed: 1 addition & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,10 +1577,9 @@ def run_diffusion_loop(
15771577
scheduler_step,
15781578
logical_axis_rules,
15791579
):
1580-
transformer = nnx.merge(graphdef, state)
1581-
15821580
def scan_body(carry, t):
15831581
latents, audio_latents, s_state = carry
1582+
transformer = nnx.merge(graphdef, state)
15841583

15851584
with nn_partitioning.axis_rules(logical_axis_rules):
15861585
latents_sharded = latents

0 commit comments

Comments
 (0)