@@ -91,7 +91,7 @@ def test_transformer_block_shapes(self):
9191
9292 # Create dummy inputs
9393 hidden_states = jnp .zeros ((self .batch_size , self .seq_len , dim ))
94- audio_hidden_states = jnp .zeros ((self .batch_size , 10 , audio_dim )) # 10 audio frames
94+ audio_hidden_states = jnp .zeros ((self .batch_size , 128 , audio_dim )) # 128 audio frames for TPFA
9595 encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim ))
9696 audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim )) # reusing cross_dim for audio context
9797
@@ -221,7 +221,7 @@ def test_transformer_3d_model_instantiation_and_forward(self):
221221
222222 # Let's pass (B, L, C).
223223 hidden_states = jnp .zeros ((self .batch_size , self .seq_len , self .in_channels ))
224- audio_hidden_states = jnp .zeros ((self .batch_size , 10 , self .audio_in_channels ))
224+ audio_hidden_states = jnp .zeros ((self .batch_size , 128 , self .audio_in_channels ))
225225
226226 timestep = jnp .array ([1.0 ]) # (B,)
227227
@@ -241,7 +241,9 @@ def test_transformer_3d_model_instantiation_and_forward(self):
241241 num_frames = self .num_frames ,
242242 height = self .height ,
243243 width = self .width ,
244- audio_num_frames = 10 ,
244+ width = self .width ,
245+ audio_num_frames = 128 ,
246+ fps = 24.0 ,
245247 fps = 24.0 ,
246248 return_dict = True ,
247249 encoder_attention_mask = encoder_attention_mask ,
@@ -255,7 +257,8 @@ def test_transformer_3d_model_instantiation_and_forward(self):
255257 print (f"Model Output Audio Shape: { audio_sample .shape } " )
256258
257259 self .assertEqual (sample .shape , (self .batch_size , self .seq_len , self .out_channels ))
258- self .assertEqual (audio_sample .shape , (self .batch_size , 10 , self .audio_in_channels ))
260+ self .assertEqual (sample .shape , (self .batch_size , self .seq_len , self .out_channels ))
261+ self .assertEqual (audio_sample .shape , (self .batch_size , 128 , self .audio_in_channels ))
259262
260263 def test_scan_remat_parity (self ):
261264 """
@@ -300,7 +303,7 @@ def test_scan_remat_parity(self):
300303
301304 # Inputs
302305 hidden_states = jnp .ones ((self .batch_size , self .seq_len , self .in_channels )) * 0.5
303- audio_hidden_states = jnp .ones ((self .batch_size , 10 , self .audio_in_channels )) * 0.5
306+ audio_hidden_states = jnp .ones ((self .batch_size , 128 , self .audio_in_channels )) * 0.5
304307 timestep = jnp .array ([1.0 ])
305308 encoder_hidden_states = jnp .ones ((self .batch_size , 5 , 32 )) * 0.1
306309 audio_encoder_hidden_states = jnp .ones ((self .batch_size , 5 , 32 )) * 0.1
@@ -314,7 +317,10 @@ def test_scan_remat_parity(self):
314317 num_frames = self .num_frames ,
315318 height = self .height ,
316319 width = self .width ,
317- audio_num_frames = 10 ,
320+ height = self .height ,
321+ width = self .width ,
322+ audio_num_frames = 128 ,
323+ fps = 24.0 ,
318324 fps = 24.0 ,
319325 return_dict = True ,
320326 )
0 commit comments