@@ -900,113 +900,79 @@ def __call__(
900900 batch_size = hidden_states .shape [0 ]
901901
902902 # 1. Prepare RoPE positional embeddings
903- if video_coords is None :
904- video_coords = self .rope .prepare_video_coords (batch_size , num_frames , height , width , fps = fps )
905- if audio_coords is None :
906- audio_coords = self .audio_rope .prepare_audio_coords (batch_size , audio_num_frames )
903+ with jax .named_scope ("RoPE Preparation" ):
904+ if video_coords is None :
905+ video_coords = self .rope .prepare_video_coords (batch_size , num_frames , height , width , fps = fps )
906+ if audio_coords is None :
907+ audio_coords = self .audio_rope .prepare_audio_coords (batch_size , audio_num_frames )
907908
908- video_rotary_emb = self .rope (video_coords )
909- audio_rotary_emb = self .audio_rope (audio_coords )
909+ video_rotary_emb = self .rope (video_coords )
910+ audio_rotary_emb = self .audio_rope (audio_coords )
910911
911- video_cross_attn_rotary_emb = self .cross_attn_rope (video_coords [:, 0 :1 , :])
912- audio_cross_attn_rotary_emb = self .cross_attn_audio_rope (audio_coords [:, 0 :1 , :])
912+ video_cross_attn_rotary_emb = self .cross_attn_rope (video_coords [:, 0 :1 , :])
913+ audio_cross_attn_rotary_emb = self .cross_attn_audio_rope (audio_coords [:, 0 :1 , :])
913914
914915 # 2. Patchify input projections
915- hidden_states = self .proj_in (hidden_states )
916- audio_hidden_states = self .audio_proj_in (audio_hidden_states )
916+ with jax .named_scope ("Input Projection" ):
917+ hidden_states = self .proj_in (hidden_states )
918+ audio_hidden_states = self .audio_proj_in (audio_hidden_states )
917919
918920 # 3. Prepare timestep embeddings and modulation parameters
919- timestep_cross_attn_gate_scale_factor = self .cross_attn_timestep_scale_multiplier / self .timestep_scale_multiplier
921+ with jax .named_scope ("Timestep and Caption Projection" ):
922+ timestep_cross_attn_gate_scale_factor = self .cross_attn_timestep_scale_multiplier / self .timestep_scale_multiplier
920923
921- temb , embedded_timestep = self .time_embed (
922- timestep .flatten (),
923- hidden_dtype = hidden_states .dtype ,
924- )
925- temb = temb .reshape (batch_size , - 1 , temb .shape [- 1 ])
926- embedded_timestep = embedded_timestep .reshape (batch_size , - 1 , embedded_timestep .shape [- 1 ])
924+ temb , embedded_timestep = self .time_embed (
925+ timestep .flatten (),
926+ hidden_dtype = hidden_states .dtype ,
927+ )
928+ temb = temb .reshape (batch_size , - 1 , temb .shape [- 1 ])
929+ embedded_timestep = embedded_timestep .reshape (batch_size , - 1 , embedded_timestep .shape [- 1 ])
927930
928- temb_audio , audio_embedded_timestep = self .audio_time_embed (
929- audio_timestep .flatten (),
930- hidden_dtype = audio_hidden_states .dtype ,
931- )
932- temb_audio = temb_audio .reshape (batch_size , - 1 , temb_audio .shape [- 1 ])
933- audio_embedded_timestep = audio_embedded_timestep .reshape (batch_size , - 1 , audio_embedded_timestep .shape [- 1 ])
931+ temb_audio , audio_embedded_timestep = self .audio_time_embed (
932+ audio_timestep .flatten (),
933+ hidden_dtype = audio_hidden_states .dtype ,
934+ )
935+ temb_audio = temb_audio .reshape (batch_size , - 1 , temb_audio .shape [- 1 ])
936+ audio_embedded_timestep = audio_embedded_timestep .reshape (batch_size , - 1 , audio_embedded_timestep .shape [- 1 ])
934937
935- video_cross_attn_scale_shift , _ = self .av_cross_attn_video_scale_shift (
936- timestep .flatten (),
937- hidden_dtype = hidden_states .dtype ,
938- )
939- video_cross_attn_a2v_gate , _ = self .av_cross_attn_video_a2v_gate (
940- timestep .flatten () * timestep_cross_attn_gate_scale_factor ,
941- hidden_dtype = hidden_states .dtype ,
942- )
943- video_cross_attn_scale_shift = video_cross_attn_scale_shift .reshape (
944- batch_size , - 1 , video_cross_attn_scale_shift .shape [- 1 ]
945- )
946- video_cross_attn_a2v_gate = video_cross_attn_a2v_gate .reshape (batch_size , - 1 , video_cross_attn_a2v_gate .shape [- 1 ])
938+ video_cross_attn_scale_shift , _ = self .av_cross_attn_video_scale_shift (
939+ timestep .flatten (),
940+ hidden_dtype = hidden_states .dtype ,
941+ )
942+ video_cross_attn_a2v_gate , _ = self .av_cross_attn_video_a2v_gate (
943+ timestep .flatten () * timestep_cross_attn_gate_scale_factor ,
944+ hidden_dtype = hidden_states .dtype ,
945+ )
946+ video_cross_attn_scale_shift = video_cross_attn_scale_shift .reshape (
947+ batch_size , - 1 , video_cross_attn_scale_shift .shape [- 1 ]
948+ )
949+ video_cross_attn_a2v_gate = video_cross_attn_a2v_gate .reshape (batch_size , - 1 , video_cross_attn_a2v_gate .shape [- 1 ])
947950
948- audio_cross_attn_scale_shift , _ = self .av_cross_attn_audio_scale_shift (
949- audio_timestep .flatten (),
950- hidden_dtype = audio_hidden_states .dtype ,
951- )
952- audio_cross_attn_v2a_gate , _ = self .av_cross_attn_audio_v2a_gate (
953- audio_timestep .flatten () * timestep_cross_attn_gate_scale_factor ,
954- hidden_dtype = audio_hidden_states .dtype ,
955- )
956- audio_cross_attn_scale_shift = audio_cross_attn_scale_shift .reshape (
957- batch_size , - 1 , audio_cross_attn_scale_shift .shape [- 1 ]
958- )
959- audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate .reshape (batch_size , - 1 , audio_cross_attn_v2a_gate .shape [- 1 ])
951+ audio_cross_attn_scale_shift , _ = self .av_cross_attn_audio_scale_shift (
952+ audio_timestep .flatten (),
953+ hidden_dtype = audio_hidden_states .dtype ,
954+ )
955+ audio_cross_attn_v2a_gate , _ = self .av_cross_attn_audio_v2a_gate (
956+ audio_timestep .flatten () * timestep_cross_attn_gate_scale_factor ,
957+ hidden_dtype = audio_hidden_states .dtype ,
958+ )
959+ audio_cross_attn_scale_shift = audio_cross_attn_scale_shift .reshape (
960+ batch_size , - 1 , audio_cross_attn_scale_shift .shape [- 1 ]
961+ )
962+ audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate .reshape (batch_size , - 1 , audio_cross_attn_v2a_gate .shape [- 1 ])
960963
961- # 4. Prepare prompt embeddings
962- encoder_hidden_states = self .caption_projection (encoder_hidden_states )
963- encoder_hidden_states = encoder_hidden_states .reshape (batch_size , - 1 , hidden_states .shape [- 1 ])
964+ # 4. Prepare prompt embeddings
965+ encoder_hidden_states = self .caption_projection (encoder_hidden_states )
966+ encoder_hidden_states = encoder_hidden_states .reshape (batch_size , - 1 , hidden_states .shape [- 1 ])
964967
965- audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
966- audio_encoder_hidden_states = audio_encoder_hidden_states .reshape (batch_size , - 1 , audio_hidden_states .shape [- 1 ])
968+ audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
969+ audio_encoder_hidden_states = audio_encoder_hidden_states .reshape (batch_size , - 1 , audio_hidden_states .shape [- 1 ])
967970
968971 # 5. Run transformer blocks
969972 def scan_fn (carry , block ):
970973 hidden_states , audio_hidden_states , rngs_carry = carry
971- hidden_states_out , audio_hidden_states_out = block (
972- hidden_states = hidden_states ,
973- audio_hidden_states = audio_hidden_states ,
974- encoder_hidden_states = encoder_hidden_states ,
975- audio_encoder_hidden_states = audio_encoder_hidden_states ,
976- temb = temb ,
977- temb_audio = temb_audio ,
978- temb_ca_scale_shift = video_cross_attn_scale_shift ,
979- temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
980- temb_ca_gate = video_cross_attn_a2v_gate ,
981- temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
982- video_rotary_emb = video_rotary_emb ,
983- audio_rotary_emb = audio_rotary_emb ,
984- ca_video_rotary_emb = video_cross_attn_rotary_emb ,
985- ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
986- encoder_attention_mask = encoder_attention_mask ,
987- audio_encoder_attention_mask = audio_encoder_attention_mask ,
988- )
989- return (
990- hidden_states_out .astype (hidden_states .dtype ),
991- audio_hidden_states_out .astype (audio_hidden_states .dtype ),
992- rngs_carry ,
993- ), None
994-
995- if self .scan_layers :
996- rematted_scan_fn = self .gradient_checkpoint .apply (
997- scan_fn , self .names_which_can_be_saved , self .names_which_can_be_offloaded , prevent_cse = not self .scan_layers
998- )
999- carry = (hidden_states , audio_hidden_states , nnx .Rngs (0 )) # Placeholder RNGs for now if not used in block
1000- (hidden_states , audio_hidden_states , _ ), _ = nnx .scan (
1001- rematted_scan_fn ,
1002- length = self .num_layers ,
1003- in_axes = (nnx .Carry , 0 ),
1004- out_axes = (nnx .Carry , 0 ),
1005- transform_metadata = {nnx .PARTITION_NAME : "layers" },
1006- )(carry , self .transformer_blocks )
1007- else :
1008- for block in self .transformer_blocks :
1009- hidden_states , audio_hidden_states = block (
974+ with jax .named_scope ("Transformer Block i" ):
975+ hidden_states_out , audio_hidden_states_out = block (
1010976 hidden_states = hidden_states ,
1011977 audio_hidden_states = audio_hidden_states ,
1012978 encoder_hidden_states = encoder_hidden_states ,
@@ -1024,6 +990,45 @@ def scan_fn(carry, block):
1024990 encoder_attention_mask = encoder_attention_mask ,
1025991 audio_encoder_attention_mask = audio_encoder_attention_mask ,
1026992 )
993+ return (
994+ hidden_states_out .astype (hidden_states .dtype ),
995+ audio_hidden_states_out .astype (audio_hidden_states .dtype ),
996+ rngs_carry ,
997+ ), None
998+
999+ with jax .named_scope ("Transformer Blocks" ):
1000+ if self .scan_layers :
1001+ rematted_scan_fn = self .gradient_checkpoint .apply (
1002+ scan_fn , self .names_which_can_be_saved , self .names_which_can_be_offloaded , prevent_cse = not self .scan_layers
1003+ )
1004+ carry = (hidden_states , audio_hidden_states , nnx .Rngs (0 )) # Placeholder RNGs for now if not used in block
1005+ (hidden_states , audio_hidden_states , _ ), _ = nnx .scan (
1006+ rematted_scan_fn ,
1007+ length = self .num_layers ,
1008+ in_axes = (nnx .Carry , 0 ),
1009+ out_axes = (nnx .Carry , 0 ),
1010+ transform_metadata = {nnx .PARTITION_NAME : "layers" },
1011+ )(carry , self .transformer_blocks )
1012+ else :
1013+ for block in self .transformer_blocks :
1014+ hidden_states , audio_hidden_states = block (
1015+ hidden_states = hidden_states ,
1016+ audio_hidden_states = audio_hidden_states ,
1017+ encoder_hidden_states = encoder_hidden_states ,
1018+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
1019+ temb = temb ,
1020+ temb_audio = temb_audio ,
1021+ temb_ca_scale_shift = video_cross_attn_scale_shift ,
1022+ temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1023+ temb_ca_gate = video_cross_attn_a2v_gate ,
1024+ temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1025+ video_rotary_emb = video_rotary_emb ,
1026+ audio_rotary_emb = audio_rotary_emb ,
1027+ ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1028+ ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1029+ encoder_attention_mask = encoder_attention_mask ,
1030+ audio_encoder_attention_mask = audio_encoder_attention_mask ,
1031+ )
10271032
10281033 # 6. Output layers
10291034 scale_shift_values = jnp .expand_dims (self .scale_shift_table , axis = (0 , 1 )) + jnp .expand_dims (embedded_timestep , axis = 2 )
0 commit comments