@@ -1009,29 +1009,29 @@ def __call__(
10091009 with self .conditional_named_scope ("transformer_block" ):
10101010 def scan_fn (carry , block ):
10111011 hidden_states , audio_hidden_states , rngs_carry = carry
1012- hidden_states_out , audio_hidden_states_out = block (
1013- hidden_states = hidden_states ,
1014- audio_hidden_states = audio_hidden_states ,
1015- encoder_hidden_states = encoder_hidden_states ,
1016- audio_encoder_hidden_states = audio_encoder_hidden_states ,
1017- temb = temb ,
1018- temb_audio = temb_audio ,
1019- temb_ca_scale_shift = video_cross_attn_scale_shift ,
1020- temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1021- temb_ca_gate = video_cross_attn_a2v_gate ,
1022- temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1023- video_rotary_emb = video_rotary_emb ,
1024- audio_rotary_emb = audio_rotary_emb ,
1025- ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1026- ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1027- encoder_attention_mask = encoder_attention_mask ,
1028- audio_encoder_attention_mask = audio_encoder_attention_mask ,
1029- )
1030- return (
1031- hidden_states_out .astype (hidden_states .dtype ),
1032- audio_hidden_states_out .astype (audio_hidden_states .dtype ),
1033- rngs_carry ,
1034- ), None
1012+ hidden_states_out , audio_hidden_states_out = block (
1013+ hidden_states = hidden_states ,
1014+ audio_hidden_states = audio_hidden_states ,
1015+ encoder_hidden_states = encoder_hidden_states ,
1016+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
1017+ temb = temb ,
1018+ temb_audio = temb_audio ,
1019+ temb_ca_scale_shift = video_cross_attn_scale_shift ,
1020+ temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1021+ temb_ca_gate = video_cross_attn_a2v_gate ,
1022+ temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1023+ video_rotary_emb = video_rotary_emb ,
1024+ audio_rotary_emb = audio_rotary_emb ,
1025+ ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1026+ ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1027+ encoder_attention_mask = encoder_attention_mask ,
1028+ audio_encoder_attention_mask = audio_encoder_attention_mask ,
1029+ )
1030+ return (
1031+ hidden_states_out .astype (hidden_states .dtype ),
1032+ audio_hidden_states_out .astype (audio_hidden_states .dtype ),
1033+ rngs_carry ,
1034+ ), None
10351035
10361036 if self .scan_layers :
10371037 rematted_scan_fn = self .gradient_checkpoint .apply (
0 commit comments