Skip to content

Commit 1815411

Browse files
committed
dim changed in test
1 parent d639254 commit 1815411

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,21 @@ def test_transformer_block_shapes(self):
6868
"""
6969
print("\n=== Testing LTX2VideoTransformerBlock Shapes ===")
7070

71-
dim = 32
72-
audio_dim = 16
73-
cross_dim = 20 # context dim
71+
dim = 1024
72+
audio_dim = 1024
73+
cross_dim = 64 # context dim
7474

7575
# NNX sharding context
7676
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
7777
block = LTX2VideoTransformerBlock(
7878
rngs=self.rngs,
7979
dim=dim,
8080
num_attention_heads=8,
81-
attention_head_dim=4,
81+
attention_head_dim=128,
8282
cross_attention_dim=cross_dim,
8383
audio_dim=audio_dim,
8484
audio_num_attention_heads=8,
85-
audio_attention_head_dim=2,
85+
audio_attention_head_dim=128,
8686
audio_cross_attention_dim=cross_dim,
8787
activation_fn="gelu",
8888
qk_norm="rms_norm_across_heads",
@@ -181,14 +181,14 @@ def test_transformer_3d_model_instantiation_and_forward(self):
181181
patch_size=self.patch_size,
182182
patch_size_t=self.patch_size_t,
183183
num_attention_heads=8,
184-
attention_head_dim=4,
184+
attention_head_dim=128,
185185
num_layers=1, # 1 layer for speed
186186
caption_channels=32, # small for test
187187
cross_attention_dim=32,
188188
audio_in_channels=self.audio_in_channels,
189189
audio_out_channels= self.audio_in_channels,
190190
audio_num_attention_heads=8,
191-
audio_attention_head_dim=2,
191+
audio_attention_head_dim=128,
192192
audio_cross_attention_dim=32,
193193
mesh=self.mesh,
194194
)

0 commit comments

Comments
 (0)