Skip to content

Commit 851a852

Browse files
committed
changes to transformer and test
1 parent 1815411 commit 851a852

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -704,11 +704,11 @@ def init_block(rngs):
704704
dim=inner_dim,
705705
num_attention_heads=self.num_attention_heads,
706706
attention_head_dim=self.attention_head_dim,
707-
cross_attention_dim=self.cross_attention_dim,
707+
cross_attention_dim=inner_dim,
708708
audio_dim=audio_inner_dim,
709709
audio_num_attention_heads=self.audio_num_attention_heads,
710710
audio_attention_head_dim=self.audio_attention_head_dim,
711-
audio_cross_attention_dim=self.audio_cross_attention_dim,
711+
audio_cross_attention_dim=audio_inner_dim,
712712
activation_fn=self.activation_fn,
713713
qk_norm=self.qk_norm,
714714
attention_bias=self.attention_bias,
@@ -735,11 +735,11 @@ def init_block(rngs):
735735
dim=inner_dim,
736736
num_attention_heads=self.num_attention_heads,
737737
attention_head_dim=self.attention_head_dim,
738-
cross_attention_dim=self.cross_attention_dim,
738+
cross_attention_dim=inner_dim,
739739
audio_dim=audio_inner_dim,
740740
audio_num_attention_heads=self.audio_num_attention_heads,
741741
audio_attention_head_dim=self.audio_attention_head_dim,
742-
audio_cross_attention_dim=self.audio_cross_attention_dim,
742+
audio_cross_attention_dim=audio_inner_dim,
743743
activation_fn=self.activation_fn,
744744
qk_norm=self.qk_norm,
745745
attention_bias=self.attention_bias,

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_transformer_block_shapes(self):
9191

9292
# Create dummy inputs
9393
hidden_states = jnp.zeros((self.batch_size, self.seq_len, dim))
94-
audio_hidden_states = jnp.zeros((self.batch_size, 10, audio_dim)) # 10 audio frames
94+
audio_hidden_states = jnp.zeros((self.batch_size, 128, audio_dim)) # 128 audio frames for TPFA
9595
encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim))
9696
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim)) # reusing cross_dim for audio context
9797

@@ -221,7 +221,7 @@ def test_transformer_3d_model_instantiation_and_forward(self):
221221

222222
# Let's pass (B, L, C).
223223
hidden_states = jnp.zeros((self.batch_size, self.seq_len, self.in_channels))
224-
audio_hidden_states = jnp.zeros((self.batch_size, 10, self.audio_in_channels))
224+
audio_hidden_states = jnp.zeros((self.batch_size, 128, self.audio_in_channels))
225225

226226
timestep = jnp.array([1.0]) # (B,)
227227

@@ -241,7 +241,7 @@ def test_transformer_3d_model_instantiation_and_forward(self):
241241
num_frames=self.num_frames,
242242
height=self.height,
243243
width=self.width,
244-
audio_num_frames=10,
244+
audio_num_frames=128,
245245
fps=24.0,
246246
return_dict=True,
247247
encoder_attention_mask=encoder_attention_mask,
@@ -255,7 +255,7 @@ def test_transformer_3d_model_instantiation_and_forward(self):
255255
print(f"Model Output Audio Shape: {audio_sample.shape}")
256256

257257
self.assertEqual(sample.shape, (self.batch_size, self.seq_len, self.out_channels))
258-
self.assertEqual(audio_sample.shape, (self.batch_size, 10, self.audio_in_channels))
258+
self.assertEqual(audio_sample.shape, (self.batch_size, 128, self.audio_in_channels))
259259

260260
def test_scan_remat_parity(self):
261261
"""
@@ -300,7 +303,7 @@ def test_scan_remat_parity(self):
300303

301304
# Inputs
302305
hidden_states = jnp.ones((self.batch_size, self.seq_len, self.in_channels)) * 0.5
303-
audio_hidden_states = jnp.ones((self.batch_size, 10, self.audio_in_channels)) * 0.5
306+
audio_hidden_states = jnp.ones((self.batch_size, 128, self.audio_in_channels)) * 0.5
304307
timestep = jnp.array([1.0])
305308
encoder_hidden_states = jnp.ones((self.batch_size, 5, 32)) * 0.1
306309
audio_encoder_hidden_states = jnp.ones((self.batch_size, 5, 32)) * 0.1
@@ -314,7 +314,7 @@ def test_scan_remat_parity(self):
314314
num_frames=self.num_frames,
315315
height=self.height,
316316
width=self.width,
317-
audio_num_frames=10,
317+
audio_num_frames=128,
318318
fps=24.0,
319319
return_dict=True,
320320
)

0 commit comments

Comments
 (0)