@@ -92,8 +92,8 @@ def test_transformer_block_shapes(self):
9292 # Create dummy inputs
9393 hidden_states = jnp .zeros ((self .batch_size , self .seq_len , dim ))
9494 audio_hidden_states = jnp .zeros ((self .batch_size , 128 , audio_dim )) # 128 audio frames for TPFA
95- encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim ))
96- audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim )) # reusing cross_dim for audio context
95+ encoder_hidden_states = jnp .zeros ((self .batch_size , 128 , cross_dim )) # 128 for TPFA
96+ audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 128 , cross_dim )) # reusing cross_dim for audio context
9797
9898 # Dummy scale/shift/gate modulations
9999 # These match the shapes expected by the block internal calculation logic
@@ -225,10 +225,10 @@ def test_transformer_3d_model_instantiation_and_forward(self):
225225
226226 timestep = jnp .array ([1.0 ]) # (B,)
227227
228- encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , 32 )) # (B, Lc, Dc)
229- audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , 32 ))
230- encoder_attention_mask = jnp .ones ((self .batch_size , 5 ), dtype = jnp .float32 )
231- audio_encoder_attention_mask = jnp .ones ((self .batch_size , 5 ), dtype = jnp .float32 )
228+ encoder_hidden_states = jnp .zeros ((self .batch_size , 128 , 32 )) # (B, Lc, Dc) # 128 for TPFA
229+ audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 128 , 32 ))
230+ encoder_attention_mask = jnp .ones ((self .batch_size , 128 ), dtype = jnp .float32 )
231+ audio_encoder_attention_mask = jnp .ones ((self .batch_size , 128 ), dtype = jnp .float32 )
232232
233233 # Forward
234234 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -303,8 +303,8 @@ def test_scan_remat_parity(self):
303303 hidden_states = jnp .ones ((self .batch_size , self .seq_len , self .in_channels )) * 0.5
304304 audio_hidden_states = jnp .ones ((self .batch_size , 128 , self .audio_in_channels )) * 0.5
305305 timestep = jnp .array ([1.0 ])
306- encoder_hidden_states = jnp .ones ((self .batch_size , 5 , 32 )) * 0.1
307- audio_encoder_hidden_states = jnp .ones ((self .batch_size , 5 , 32 )) * 0.1
306+ encoder_hidden_states = jnp .ones ((self .batch_size , 128 , 32 )) * 0.1
307+ audio_encoder_hidden_states = jnp .ones ((self .batch_size , 128 , 32 )) * 0.1
308308
309309 inp_args = dict (
310310 hidden_states = hidden_states ,
0 commit comments