Skip to content

Commit 089f8ac

Browse files
fix unit tests
1 parent cd16f28 commit 089f8ac

1 file changed

Lines changed: 3 additions & 15 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -221,19 +221,13 @@ def test_wan_upsample(self):
221221
in_depth, in_height, in_width = 10, 32, 32
222222
in_channels = 3
223223

224-
dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels))
224+
dummy_input = jnp.ones((batch_size * in_depth, in_height, in_width, in_channels))
225225

226226
upsample = WanUpsample(scale_factor=(2.0, 2.0))
227227

228228
# --- Test Case 1: depth > 1 ---
229229
output = upsample(dummy_input)
230-
assert output.shape == (1, 10, 64, 64, 3)
231-
232-
in_depth = 1
233-
dummy_input = jnp.ones((batch_size, in_depth, in_height, in_width, in_channels))
234-
# --- Test Case 1: depth == 1 ---
235-
output = upsample(dummy_input)
236-
assert output.shape == (1, 1, 64, 64, 3)
230+
assert output.shape == (10, 64, 64, 3)
237231

238232
def test_wan_resample(self):
239233
# TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity
@@ -260,13 +254,7 @@ def test_wan_resample(self):
260254
input_shape = (batch, t, h, w, dim)
261255
dummy_input = jnp.ones(input_shape)
262256
output = wan_resample(dummy_input)
263-
assert output.shape == (batch, t, h // 2, h // 2, dim)
264-
breakpoint()
265-
266-
# --- Test Case 1: downsample3d ---
267-
dim = 192
268-
input_shape = (1, dim, 1, 240, 360)
269-
torch_wan_resample = WanResample(dim=dim, mode="downsample3d")
257+
assert output.shape == (batch, t, h // 2, w // 2, dim)
270258

271259
def test_3d_conv(self):
272260
key = jax.random.key(0)

0 commit comments

Comments
 (0)