@@ -1026,24 +1026,24 @@ def scan_fn(carry, block):
10261026 )(carry , self .transformer_blocks )
10271027 else :
10281028 for block in self .transformer_blocks :
1029- hidden_states , audio_hidden_states = block (
1030- hidden_states = hidden_states ,
1031- audio_hidden_states = audio_hidden_states ,
1032- encoder_hidden_states = encoder_hidden_states ,
1033- audio_encoder_hidden_states = audio_encoder_hidden_states ,
1034- temb = temb ,
1035- temb_audio = temb_audio ,
1036- temb_ca_scale_shift = video_cross_attn_scale_shift ,
1037- temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1038- temb_ca_gate = video_cross_attn_a2v_gate ,
1039- temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1040- video_rotary_emb = video_rotary_emb ,
1041- audio_rotary_emb = audio_rotary_emb ,
1042- ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1043- ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1044- encoder_attention_mask = encoder_attention_mask ,
1045- audio_encoder_attention_mask = audio_encoder_attention_mask ,
1046- )
1029+ hidden_states , audio_hidden_states = block (
1030+ hidden_states = hidden_states ,
1031+ audio_hidden_states = audio_hidden_states ,
1032+ encoder_hidden_states = encoder_hidden_states ,
1033+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
1034+ temb = temb ,
1035+ temb_audio = temb_audio ,
1036+ temb_ca_scale_shift = video_cross_attn_scale_shift ,
1037+ temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1038+ temb_ca_gate = video_cross_attn_a2v_gate ,
1039+ temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1040+ video_rotary_emb = video_rotary_emb ,
1041+ audio_rotary_emb = audio_rotary_emb ,
1042+ ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1043+ ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1044+ encoder_attention_mask = encoder_attention_mask ,
1045+ audio_encoder_attention_mask = audio_encoder_attention_mask ,
1046+ )
10471047
10481048 # 6. Output layers
10491049 with jax .named_scope ("Output Projection & Norm" ):
0 commit comments