@@ -1247,13 +1247,6 @@ def step_fn(carry, t):
12471247 video_embeds_sharded = video_embeds
12481248 audio_embeds_sharded = audio_embeds
12491249
1250- if not self .transformer .scan_layers :
1251- activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
1252- latents_jax_sharded = jax .lax .with_sharding_constraint (latents_jax , activation_axis_names )
1253- audio_latents_jax_sharded = jax .lax .with_sharding_constraint (audio_latents_jax , activation_axis_names )
1254- video_embeds_sharded = jax .lax .with_sharding_constraint (video_embeds , activation_axis_names )
1255- audio_embeds_sharded = jax .lax .with_sharding_constraint (audio_embeds , activation_axis_names )
1256-
12571250 noise_pred , noise_pred_audio = transformer_forward_pass (
12581251 graphdef ,
12591252 state ,
@@ -1301,6 +1294,13 @@ def step_fn(carry, t):
13011294
13021295 return (new_latents_jax .astype (latents_jax .dtype ), new_audio_latents_jax .astype (audio_latents_jax .dtype )), None
13031296
1297+ if not self .transformer .scan_layers :
1298+ activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
1299+ latents_jax = jax .lax .with_sharding_constraint (latents_jax , activation_axis_names )
1300+ audio_latents_jax = jax .lax .with_sharding_constraint (audio_latents_jax , activation_axis_names )
1301+ video_embeds = jax .lax .with_sharding_constraint (video_embeds , activation_axis_names )
1302+ audio_embeds = jax .lax .with_sharding_constraint (audio_embeds , activation_axis_names )
1303+
13041304 initial_carry = (latents_jax , audio_latents_jax )
13051305 with jax .named_scope ("denoising_loop" ):
13061306 (latents_jax , audio_latents_jax ), _ = jax .lax .scan (step_fn , initial_carry , timesteps_jax )
0 commit comments