@@ -1577,9 +1577,10 @@ def run_diffusion_loop(
15771577 scheduler_step ,
15781578 logical_axis_rules ,
15791579):
1580- def scan_body (carry , t ):
1580+ transformer = nnx .merge (graphdef , state )
1581+
1582+ def scan_body (carry , t , model ):
15811583 latents , audio_latents , s_state = carry
1582- transformer = nnx .merge (graphdef , state )
15831584
15841585 with nn_partitioning .axis_rules (logical_axis_rules ):
15851586 latents_sharded = latents
@@ -1599,7 +1600,7 @@ def scan_body(carry, t):
15991600 # Expand timestep to batch size
16001601 t_expanded = jnp .expand_dims (t , 0 ).repeat (latents .shape [0 ])
16011602
1602- noise_pred , noise_pred_audio = transformer (
1603+ noise_pred , noise_pred_audio = model (
16031604 hidden_states = latents_sharded ,
16041605 encoder_hidden_states = video_embeds_sharded ,
16051606 timestep = t_expanded ,
@@ -1660,8 +1661,8 @@ def scan_body(carry, t):
16601661 # Run scan
16611662 final_carry , _ = nnx .scan (
16621663 scan_body ,
1663- in_axes = (nnx .Carry , 0 ),
1664+ in_axes = (nnx .Carry , 0 , None ),
16641665 out_axes = (nnx .Carry , 0 ),
1665- )(initial_carry , timesteps_jax )
1666+ )(initial_carry , timesteps_jax , transformer )
16661667
16671668 return final_carry [0 ], final_carry [1 ]
0 commit comments