@@ -68,21 +68,21 @@ def test_transformer_block_shapes(self):
6868 """
6969 print ("\n === Testing LTX2VideoTransformerBlock Shapes ===" )
7070
71- dim = 32
72- audio_dim = 16
73- cross_dim = 20 # context dim
71+ dim = 1024
72+ audio_dim = 1024
73+ cross_dim = 64 # context dim
7474
7575 # NNX sharding context
7676 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
7777 block = LTX2VideoTransformerBlock (
7878 rngs = self .rngs ,
7979 dim = dim ,
8080 num_attention_heads = 8 ,
81- attention_head_dim = 4 ,
81+ attention_head_dim = 128 ,
8282 cross_attention_dim = cross_dim ,
8383 audio_dim = audio_dim ,
8484 audio_num_attention_heads = 8 ,
85- audio_attention_head_dim = 2 ,
85+ audio_attention_head_dim = 128 ,
8686 audio_cross_attention_dim = cross_dim ,
8787 activation_fn = "gelu" ,
8888 qk_norm = "rms_norm_across_heads" ,
@@ -181,14 +181,14 @@ def test_transformer_3d_model_instantiation_and_forward(self):
             patch_size=self.patch_size,
             patch_size_t=self.patch_size_t,
             num_attention_heads=8,
-            attention_head_dim=4,
+            attention_head_dim=128,
             num_layers=1,  # 1 layer for speed
             caption_channels=32,  # small for test
             cross_attention_dim=32,
             audio_in_channels=self.audio_in_channels,
             audio_out_channels=self.audio_in_channels,
             audio_num_attention_heads=8,
-            audio_attention_head_dim=2,
+            audio_attention_head_dim=128,
             audio_cross_attention_dim=32,
             mesh=self.mesh,
         )
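
A minimal sanity check of the arithmetic behind the new values, assuming `LTX2VideoTransformerBlock` follows the usual transformer convention that the block width equals `num_attention_heads * attention_head_dim`. The variable names below mirror the test's keyword arguments; the check itself is illustrative and not part of the diff:

```python
# Hedged sketch: verifies the updated test dimensions are internally
# consistent under the standard width == heads * head_dim convention.
dim = 1024
num_attention_heads = 8
attention_head_dim = 128
assert dim == num_attention_heads * attention_head_dim  # 8 * 128 == 1024

audio_dim = 1024
audio_num_attention_heads = 8
audio_attention_head_dim = 128
assert audio_dim == audio_num_attention_heads * audio_attention_head_dim
```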