Skip to content

Commit a25129b

Browse files
committed
fix
1 parent 9187ccf commit a25129b

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def test_transformer_block_shapes(self):
9292
# Create dummy inputs
9393
hidden_states = jnp.zeros((self.batch_size, self.seq_len, dim))
9494
audio_hidden_states = jnp.zeros((self.batch_size, 128, audio_dim)) # 128 audio frames for TPFA
95-
encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim))
96-
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim)) # reusing cross_dim for audio context
95+
encoder_hidden_states = jnp.zeros((self.batch_size, 128, cross_dim)) # 128 for TPFA
96+
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 128, cross_dim)) # reusing cross_dim for audio context
9797

9898
# Dummy scale/shift/gate modulations
9999
# These match the shapes expected by the block internal calculation logic
@@ -225,10 +225,10 @@ def test_transformer_3d_model_instantiation_and_forward(self):
225225

226226
timestep = jnp.array([1.0]) # (B,)
227227

228-
encoder_hidden_states = jnp.zeros((self.batch_size, 5, 32)) # (B, Lc, Dc)
229-
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, 32))
230-
encoder_attention_mask = jnp.ones((self.batch_size, 5), dtype=jnp.float32)
231-
audio_encoder_attention_mask = jnp.ones((self.batch_size, 5), dtype=jnp.float32)
228+
encoder_hidden_states = jnp.zeros((self.batch_size, 128, 32)) # (B, Lc, Dc) # 128 for TPFA
229+
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 128, 32))
230+
encoder_attention_mask = jnp.ones((self.batch_size, 128), dtype=jnp.float32)
231+
audio_encoder_attention_mask = jnp.ones((self.batch_size, 128), dtype=jnp.float32)
232232

233233
# Forward
234234
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -303,8 +303,8 @@ def test_scan_remat_parity(self):
303303
hidden_states = jnp.ones((self.batch_size, self.seq_len, self.in_channels)) * 0.5
304304
audio_hidden_states = jnp.ones((self.batch_size, 128, self.audio_in_channels)) * 0.5
305305
timestep = jnp.array([1.0])
306-
encoder_hidden_states = jnp.ones((self.batch_size, 5, 32)) * 0.1
307-
audio_encoder_hidden_states = jnp.ones((self.batch_size, 5, 32)) * 0.1
306+
encoder_hidden_states = jnp.ones((self.batch_size, 128, 32)) * 0.1
307+
audio_encoder_hidden_states = jnp.ones((self.batch_size, 128, 32)) * 0.1
308308

309309
inp_args = dict(
310310
hidden_states=hidden_states,

0 commit comments

Comments
 (0)