Skip to content

Commit b1385d2

Browse files
committed
Change in wan_vae_test.py
1 parent d616736 commit b1385d2

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def test_wan_resample(self):
266266
# channels is always last here
267267
input_shape = (batch, t, h, w, dim)
268268
dummy_input = jnp.ones(input_shape)
269-
output = wan_resample(dummy_input)
269+
output, _, _ = wan_resample(dummy_input)
270270
assert output.shape == (batch, t, h // 2, w // 2, dim)
271271

272272
def test_3d_conv(self):
@@ -347,7 +347,7 @@ def test_wan_residual(self):
347347
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
348348
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
349349
dummy_input = jnp.ones(input_shape)
350-
dummy_output = wan_residual_block(dummy_input)
350+
dummy_output, _, _ = wan_residual_block(dummy_input)
351351
assert dummy_output.shape == expected_output_shape
352352
# --- Test Case 1: different in/out dim ---
353353
in_dim = 96
@@ -356,7 +356,7 @@ def test_wan_residual(self):
356356

357357
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
358358
dummy_input = jnp.ones(input_shape)
359-
dummy_output = wan_residual_block(dummy_input)
359+
dummy_output, _, _ = wan_residual_block(dummy_input)
360360
assert dummy_output.shape == expected_output_shape
361361

362362
def test_wan_attention(self):
@@ -371,7 +371,7 @@ def test_wan_attention(self):
371371
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
372372
wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
373373
dummy_input = jnp.ones(input_shape)
374-
output = wan_attention(dummy_input)
374+
output, _, _ = wan_attention(dummy_input)
375375
assert output.shape == input_shape
376376

377377
def test_wan_midblock(self):
@@ -396,7 +396,7 @@ def test_wan_midblock(self):
396396
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
397397
wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
398398
dummy_input = jnp.ones(input_shape)
399-
output = wan_midblock(dummy_input)
399+
output, _, _ = wan_midblock(dummy_input)
400400
assert output.shape == input_shape
401401

402402
def test_wan_decode(self):

0 commit comments

Comments
 (0)