@@ -266,7 +266,7 @@ def test_wan_resample(self):
266266 # channels is always last here
267267 input_shape = (batch , t , h , w , dim )
268268 dummy_input = jnp .ones (input_shape )
269- output = wan_resample (dummy_input )
269+ output , _ , _ = wan_resample (dummy_input )
270270 assert output .shape == (batch , t , h // 2 , w // 2 , dim )
271271
272272 def test_3d_conv (self ):
@@ -347,7 +347,7 @@ def test_wan_residual(self):
347347 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
348348 wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = mesh )
349349 dummy_input = jnp .ones (input_shape )
350- dummy_output = wan_residual_block (dummy_input )
350+ dummy_output , _ , _ = wan_residual_block (dummy_input )
351351 assert dummy_output .shape == expected_output_shape
352352 # --- Test Case 1: different in/out dim ---
353353 in_dim = 96
@@ -356,7 +356,7 @@ def test_wan_residual(self):
356356
357357 wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = mesh )
358358 dummy_input = jnp .ones (input_shape )
359- dummy_output = wan_residual_block (dummy_input )
359+ dummy_output , _ , _ = wan_residual_block (dummy_input )
360360 assert dummy_output .shape == expected_output_shape
361361
362362 def test_wan_attention (self ):
@@ -371,7 +371,7 @@ def test_wan_attention(self):
371371 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
372372 wan_attention = WanAttentionBlock (dim = dim , rngs = rngs )
373373 dummy_input = jnp .ones (input_shape )
374- output = wan_attention (dummy_input )
374+ output , _ , _ = wan_attention (dummy_input )
375375 assert output .shape == input_shape
376376
377377 def test_wan_midblock (self ):
@@ -396,7 +396,7 @@ def test_wan_midblock(self):
396396 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
397397 wan_midblock = WanMidBlock (dim = dim , rngs = rngs , mesh = mesh )
398398 dummy_input = jnp .ones (input_shape )
399- output = wan_midblock (dummy_input )
399+ output , _ , _ = wan_midblock (dummy_input )
400400 assert output .shape == input_shape
401401
402402 def test_wan_decode (self ):
0 commit comments