
Commit 721b4c6

Parametrize test

Signed-off-by: Kunjan Patel <kunjanp@google.com>
Parent: 436e7d1

1 file changed: src/maxdiffusion/tests/wan_transformer_test.py
19 additions, 19 deletions
@@ -133,8 +133,7 @@ def test_wan_time_text_embedding(self):
     assert timestep_proj.shape == (batch_size, time_proj_dim)
     assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

-  @pytest.mark.parametrize("attention", ["flash", "tokamax_flash"])
-  def test_wan_block(self, attention):
+  def test_wan_block(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     pyconfig.initialize(
@@ -226,24 +225,25 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-      attention = FlaxWanAttention(
-          rngs=rngs,
-          query_dim=query_dim,
-          heads=40,
-          dim_head=128,
-          attention_kernel="flash",
-          mesh=mesh,
-          flash_block_sizes=flash_block_sizes,
-      )
-      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+    for attention_kernel in ["flash", "tokamax_flash"]:
+      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel=attention_kernel,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+        dummy_hidden_states_shape = (batch_size, 75600, query_dim)

-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_output = attention(
-        hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-    )
-    assert dummy_output.shape == dummy_hidden_states_shape
+        dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+        dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+        dummy_output = attention(
+            hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+        )
+        assert dummy_output.shape == dummy_hidden_states_shape

     # dot product
     try:
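
Note on the second hunk: the kernel sweep over "flash" and "tokamax_flash" is written as a for loop inside test_wan_attention rather than as a pytest parametrization (the first hunk removes a @pytest.mark.parametrize that had been placed on test_wan_block). One plausible reason, stated here as an assumption about the surrounding test class: @pytest.mark.parametrize is not supported on unittest.TestCase methods, and an explicit loop is the usual fallback. For a plain pytest-style test, the equivalent would look like the hypothetical sketch below; the function name and KERNELS constant are illustrative, not from the repo.

import pytest

KERNELS = ["flash", "tokamax_flash"]

@pytest.mark.parametrize("attention_kernel", KERNELS)
def test_attention_kernel_output_shape(attention_kernel):
    # Stand-in body: build the attention module with `attention_kernel`
    # and check the output shape, as the loop body in the diff does.
    assert attention_kernel in KERNELS

The trade-off: parametrization reports each kernel as its own pass/fail test case and keeps going after a failure, while the in-test loop stops at the first kernel that fails.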
