@@ -1373,15 +1373,23 @@ def __call__(
13731373 )
13741374 else :
13751375 # Old Python loop path
1376- latents_jax = latents_jax .astype (jnp .float32 )
1377- audio_latents_jax = audio_latents_jax .astype (jnp .float32 )
1376+ for i in range (len (timesteps_jax )):
1377+ t = timesteps_jax [i ]
1378+
1379+ # Isolate input sharding to scan_layers=False to avoid affecting the standard path
1380+ latents_jax_sharded = latents_jax
1381+ audio_latents_jax_sharded = audio_latents_jax
1382+
1383+ if not self .transformer .scan_layers :
1384+ activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
1385+ latents_jax_sharded = jax .lax .with_sharding_constraint (latents_jax , activation_axis_names )
1386+ audio_latents_jax_sharded = jax .lax .with_sharding_constraint (audio_latents_jax , activation_axis_names )
13781387
1379- for t in timesteps_jax :
13801388 noise_pred , noise_pred_audio = transformer_forward_pass (
13811389 graphdef ,
13821390 state ,
1383- latents_jax ,
1384- audio_latents_jax ,
1391+ latents_jax_sharded ,
1392+ audio_latents_jax_sharded ,
13851393 t ,
13861394 video_embeds_sharded ,
13871395 audio_embeds_sharded ,
@@ -1399,7 +1407,7 @@ def __call__(
13991407 if guidance_scale > 1.0 :
14001408 noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
14011409 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1402-
1410+ # Audio guidance
14031411 noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
14041412 noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
14051413
@@ -1409,13 +1417,12 @@ def __call__(
14091417 latents_step = latents_jax
14101418 audio_latents_step = audio_latents_jax
14111419
1420+ # Step
14121421 latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
1413- latents_step = latents_step .astype (jnp .float32 )
1414-
1422+
14151423 audio_latents_step , _ = self .scheduler .step (
14161424 scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
14171425 )
1418- audio_latents_step = audio_latents_step .astype (jnp .float32 )
14191426
14201427 if guidance_scale > 1.0 :
14211428 latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
0 commit comments