@@ -1238,8 +1238,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12381238
12391239 import time
12401240 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1241- for i , t_val in enumerate ( timesteps ):
1242- t = timesteps_jax [ i ]
1241+ def step_fn ( carry , t ):
1242+ latents_jax , audio_latents_jax = carry
12431243
12441244 # Isolate input sharding to scan_layers=False to avoid affecting the standard path
12451245 latents_jax_sharded = latents_jax
@@ -1293,11 +1293,16 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12931293 )
12941294
12951295 if guidance_scale > 1.0 :
1296- latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1297- audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
1296+ new_latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1297+ new_audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
12981298 else :
1299- latents_jax = latents_step
1300- audio_latents_jax = audio_latents_step
1299+ new_latents_jax = latents_step
1300+ new_audio_latents_jax = audio_latents_step
1301+
1302+ return (new_latents_jax , new_audio_latents_jax ), None
1303+
1304+ initial_carry = (latents_jax , audio_latents_jax )
1305+ (latents_jax , audio_latents_jax ), _ = jax .lax .scan (step_fn , initial_carry , timesteps_jax )
13011306
13021307
13031308
0 commit comments