Skip to content

Commit 13bcbb8

Browse files
committed
Pass transformer as broadcasted argument to nnx.scan to fix TraceContextError
1 parent 70b78b3 commit 13bcbb8

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,9 +1577,10 @@ def run_diffusion_loop(
15771577
scheduler_step,
15781578
logical_axis_rules,
15791579
):
1580-
def scan_body(carry, t):
1580+
transformer = nnx.merge(graphdef, state)
1581+
1582+
def scan_body(carry, t, model):
15811583
latents, audio_latents, s_state = carry
1582-
transformer = nnx.merge(graphdef, state)
15831584

15841585
with nn_partitioning.axis_rules(logical_axis_rules):
15851586
latents_sharded = latents
@@ -1599,7 +1600,7 @@ def scan_body(carry, t):
15991600
# Expand timestep to batch size
16001601
t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0])
16011602

1602-
noise_pred, noise_pred_audio = transformer(
1603+
noise_pred, noise_pred_audio = model(
16031604
hidden_states=latents_sharded,
16041605
encoder_hidden_states=video_embeds_sharded,
16051606
timestep=t_expanded,
@@ -1660,8 +1661,8 @@ def scan_body(carry, t):
16601661
# Run scan
16611662
final_carry, _ = nnx.scan(
16621663
scan_body,
1663-
in_axes=(nnx.Carry, 0),
1664+
in_axes=(nnx.Carry, 0, None),
16641665
out_axes=(nnx.Carry, 0),
1665-
)(initial_carry, timesteps_jax)
1666+
)(initial_carry, timesteps_jax, transformer)
16661667

16671668
return final_carry[0], final_carry[1]

0 commit comments

Comments
 (0)