@@ -268,16 +268,16 @@ def test_transformer_3d_model_dot_product_attention(self):
268268 out_channels = self .out_channels ,
269269 patch_size = self .patch_size ,
270270 patch_size_t = self .patch_size_t ,
271- num_attention_heads = self . num_attention_heads ,
272- attention_head_dim = self . attention_head_dim ,
273- cross_attention_dim = self . cross_attention_dim ,
271+ num_attention_heads = 8 ,
272+ attention_head_dim = 128 ,
273+ cross_attention_dim = 32 ,
274274 audio_in_channels = self .audio_in_channels ,
275- audio_out_channels = self .audio_out_channels ,
275+ audio_out_channels = self .audio_in_channels ,
276276 audio_patch_size = self .audio_patch_size ,
277277 audio_patch_size_t = self .audio_patch_size_t ,
278- audio_num_attention_heads = self . audio_num_attention_heads ,
279- audio_attention_head_dim = self . audio_attention_head_dim ,
280- audio_cross_attention_dim = self . audio_cross_attention_dim ,
278+ audio_num_attention_heads = 8 ,
279+ audio_attention_head_dim = 128 ,
280+ audio_cross_attention_dim = 32 ,
281281 num_layers = 1 , # Reduced layers for speed
282282 config = self .config ,
283283 scan_layers = False ,
0 commit comments