@@ -77,12 +77,12 @@ def test_transformer_block_shapes(self):
7777 block = LTX2VideoTransformerBlock (
7878 rngs = self .rngs ,
7979 dim = dim ,
80- num_attention_heads = 4 ,
81- attention_head_dim = 8 ,
80+ num_attention_heads = 8 ,
81+ attention_head_dim = 4 ,
8282 cross_attention_dim = cross_dim ,
8383 audio_dim = audio_dim ,
84- audio_num_attention_heads = 4 ,
85- audio_attention_head_dim = 4 ,
84+ audio_num_attention_heads = 8 ,
85+ audio_attention_head_dim = 2 ,
8686 audio_cross_attention_dim = cross_dim ,
8787 activation_fn = "gelu" ,
8888 qk_norm = "rms_norm_across_heads" ,
@@ -180,15 +180,15 @@ def test_transformer_3d_model_instantiation_and_forward(self):
180180 out_channels = self .out_channels ,
181181 patch_size = self .patch_size ,
182182 patch_size_t = self .patch_size_t ,
183- num_attention_heads = 2 ,
184- attention_head_dim = 16 ,
183+ num_attention_heads = 8 ,
184+ attention_head_dim = 4 ,
185185 num_layers = 1 , # 1 layer for speed
186186 caption_channels = 32 , # small for test
187187 cross_attention_dim = 32 ,
188188 audio_in_channels = self .audio_in_channels ,
189189 audio_out_channels = self .audio_in_channels ,
190- audio_num_attention_heads = 2 ,
191- audio_attention_head_dim = 16 ,
190+ audio_num_attention_heads = 8 ,
191+ audio_attention_head_dim = 2 ,
192192 audio_cross_attention_dim = 32 ,
193193 mesh = self .mesh ,
194194 )
@@ -271,15 +271,15 @@ def test_scan_remat_parity(self):
271271 out_channels = self .out_channels ,
272272 patch_size = self .patch_size ,
273273 patch_size_t = self .patch_size_t ,
274- num_attention_heads = 2 ,
275- attention_head_dim = 16 ,
274+ num_attention_heads = 8 ,
275+ attention_head_dim = 4 ,
276276 num_layers = 2 , # Need >1 layer to test scan effectively
277277 caption_channels = 32 ,
278278 cross_attention_dim = 32 ,
279279 audio_in_channels = self .audio_in_channels ,
280280 audio_out_channels = self .audio_in_channels ,
281- audio_num_attention_heads = 2 ,
282- audio_attention_head_dim = 16 ,
281+ audio_num_attention_heads = 8 ,
282+ audio_attention_head_dim = 4 ,
283283 audio_cross_attention_dim = 32
284284 )
285285
0 commit comments