@@ -1089,6 +1089,12 @@ def __call__(
10891089 video_ca_timestep * timestep_cross_attn_gate_scale_factor ,
10901090 hidden_dtype = hidden_states .dtype ,
10911091 )
1092+
1093+ if video_cross_attn_scale_shift .shape [0 ] < batch_size :
1094+ video_cross_attn_scale_shift = jnp .repeat (video_cross_attn_scale_shift , batch_size // video_cross_attn_scale_shift .shape [0 ], axis = 0 )
1095+ if video_cross_attn_a2v_gate .shape [0 ] < batch_size :
1096+ video_cross_attn_a2v_gate = jnp .repeat (video_cross_attn_a2v_gate , batch_size // video_cross_attn_a2v_gate .shape [0 ], axis = 0 )
1097+
10921098 video_cross_attn_scale_shift = video_cross_attn_scale_shift .reshape (
10931099 batch_size , - 1 , video_cross_attn_scale_shift .shape [- 1 ]
10941100 )
@@ -1102,6 +1108,12 @@ def __call__(
11021108 audio_ca_timestep * timestep_cross_attn_gate_scale_factor ,
11031109 hidden_dtype = audio_hidden_states .dtype ,
11041110 )
1111+
1112+ if audio_cross_attn_scale_shift .shape [0 ] < batch_size :
1113+ audio_cross_attn_scale_shift = jnp .repeat (audio_cross_attn_scale_shift , batch_size // audio_cross_attn_scale_shift .shape [0 ], axis = 0 )
1114+ if audio_cross_attn_v2a_gate .shape [0 ] < batch_size :
1115+ audio_cross_attn_v2a_gate = jnp .repeat (audio_cross_attn_v2a_gate , batch_size // audio_cross_attn_v2a_gate .shape [0 ], axis = 0 )
1116+
11051117 audio_cross_attn_scale_shift = audio_cross_attn_scale_shift .reshape (
11061118 batch_size , - 1 , audio_cross_attn_scale_shift .shape [- 1 ]
11071119 )
0 commit comments