@@ -184,12 +184,12 @@ def test_transformer_3d_model_instantiation_and_forward(self):
             attention_head_dim=128,
             num_layers=1,  # 1 layer for speed
             caption_channels=32,  # small for test
-            cross_attention_dim=32,
+            cross_attention_dim=1024,
             audio_in_channels=self.audio_in_channels,
             audio_out_channels=self.audio_in_channels,
             audio_num_attention_heads=8,
             audio_attention_head_dim=128,
-            audio_cross_attention_dim=32,
+            audio_cross_attention_dim=1024,
             mesh=self.mesh,
         )

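For context on why 1024 is the value used on both sides of this change: with `num_attention_heads=8` and `attention_head_dim=128`, the transformer's inner dimension is 8 × 128 = 1024, and the cross-attention conditioning width presumably has to match it (the raw caption features, `caption_channels=32`, would be projected up to this width before cross-attention). A minimal sketch of that bookkeeping, assuming a PixArt-style caption projection; the variable names and shapes below are illustrative, not the model's actual attributes:

```python
import jax
import jax.numpy as jnp

# Dimensions taken from the test config above.
num_attention_heads = 8
attention_head_dim = 128
caption_channels = 32
cross_attention_dim = 1024

# The transformer's inner width is heads * head_dim; the updated
# cross_attention_dim matches it rather than the raw caption width.
inner_dim = num_attention_heads * attention_head_dim
assert cross_attention_dim == inner_dim  # 8 * 128 == 1024

# Hypothetical caption projection: raw caption features are mapped up
# from caption_channels to cross_attention_dim before cross-attention.
batch, text_seq_len = 1, 77  # assumed sequence length
key = jax.random.PRNGKey(0)
captions = jax.random.normal(key, (batch, text_seq_len, caption_channels))
w = jax.random.normal(key, (caption_channels, cross_attention_dim)) * 0.02
encoder_hidden_states = captions @ w
assert encoder_hidden_states.shape == (batch, text_seq_len, cross_attention_dim)
```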
@@ -270,15 +270,15 @@ def test_transformer_3d_model_dot_product_attention(self):
             patch_size_t=self.patch_size_t,
             num_attention_heads=8,
             attention_head_dim=128,
-            cross_attention_dim=32,
+            cross_attention_dim=1024,
             caption_channels=32,
             audio_in_channels=self.audio_in_channels,
             audio_out_channels=self.audio_in_channels,
             audio_patch_size=1,
             audio_patch_size_t=1,
             audio_num_attention_heads=8,
             audio_attention_head_dim=128,
-            audio_cross_attention_dim=32,
+            audio_cross_attention_dim=1024,
             num_layers=1,  # Reduced layers for speed
             scan_layers=False,
             mesh=self.mesh,
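The audio branch follows the same pattern: `audio_num_attention_heads=8` × `audio_attention_head_dim=128` gives an audio inner dimension of 1024, which the new `audio_cross_attention_dim` matches. A hedged sketch of the dummy audio conditioning tensor such a test would build (the shapes and names here are assumptions, not the test's actual fixtures):

```python
import jax.numpy as jnp

audio_num_attention_heads = 8
audio_attention_head_dim = 128
audio_cross_attention_dim = 1024
assert audio_cross_attention_dim == audio_num_attention_heads * audio_attention_head_dim

# Dummy audio conditioning whose last dimension matches the config.
batch, audio_seq_len = 1, 32  # hypothetical sequence length
audio_encoder_hidden_states = jnp.zeros(
    (batch, audio_seq_len, audio_cross_attention_dim), dtype=jnp.float32
)
```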