Skip to content

Commit bee57ba

Browse files
update transformer test.
1 parent fcb1ab1 commit bee57ba

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def test_wan_model(self):
248248

249249
batch_size = 1
250250
channels = 16
251-
frames = 21
251+
frames = 1
252252
height = 90
253253
width = 160
254254
hidden_states_shape = (batch_size, channels, frames, height, width)
@@ -262,12 +262,8 @@ def test_wan_model(self):
262262

263263
mesh = Mesh(devices_array, config.mesh_axes)
264264
batch_size = 1
265-
wan_model = WanModel(
266-
rngs=rngs,
267-
attention="flash",
268-
mesh=mesh,
269-
flash_block_sizes=flash_block_sizes,
270-
)
265+
num_layers = 1
266+
wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
271267

272268
dummy_timestep = jnp.ones((batch_size))
273269
dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
@@ -277,7 +273,7 @@ def test_wan_model(self):
277273
timestep=dummy_timestep,
278274
encoder_hidden_states=dummy_encoder_hidden_states,
279275
is_uncond=jnp.array(True, dtype=jnp.bool_),
280-
slg_mask=jnp.zeros(40, dtype=jnp.bool_),
276+
slg_mask=jnp.zeros(num_layers, dtype=jnp.bool_),
281277
)
282278
assert dummy_output.shape == hidden_states_shape
283279

0 commit comments

Comments
 (0)