@@ -221,19 +221,13 @@ def test_wan_upsample(self):
221221 in_depth , in_height , in_width = 10 , 32 , 32
222222 in_channels = 3
223223
224- dummy_input = jnp .ones ((batch_size , in_depth , in_height , in_width , in_channels ))
224+ dummy_input = jnp .ones ((batch_size * in_depth , in_height , in_width , in_channels ))
225225
226226 upsample = WanUpsample (scale_factor = (2.0 , 2.0 ))
227227
228228 # --- Test Case 1: depth > 1 ---
229229 output = upsample (dummy_input )
230- assert output .shape == (1 , 10 , 64 , 64 , 3 )
231-
232- in_depth = 1
233- dummy_input = jnp .ones ((batch_size , in_depth , in_height , in_width , in_channels ))
234- # --- Test Case 1: depth == 1 ---
235- output = upsample (dummy_input )
236- assert output .shape == (1 , 1 , 64 , 64 , 3 )
230+ assert output .shape == (10 , 64 , 64 , 3 )
237231
238232 def test_wan_resample (self ):
239233 # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity
@@ -260,13 +254,7 @@ def test_wan_resample(self):
260254 input_shape = (batch , t , h , w , dim )
261255 dummy_input = jnp .ones (input_shape )
262256 output = wan_resample (dummy_input )
263- assert output .shape == (batch , t , h // 2 , h // 2 , dim )
264- breakpoint ()
265-
266- # --- Test Case 1: downsample3d ---
267- dim = 192
268- input_shape = (1 , dim , 1 , 240 , 360 )
269- torch_wan_resample = WanResample (dim = dim , mode = "downsample3d" )
257+ assert output .shape == (batch , t , h // 2 , w // 2 , dim )
270258
271259 def test_3d_conv (self ):
272260 key = jax .random .key (0 )
0 commit comments