@@ -133,8 +133,7 @@ def test_wan_time_text_embedding(self):
133 133     assert timestep_proj.shape == (batch_size, time_proj_dim)
134 134     assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
135 135
136     -   @pytest.mark.parametrize("attention", ["flash", "tokamax_flash"])
137     -   def test_wan_block(self, attention):
    136 +   def test_wan_block(self):
138 137     key = jax.random.key(0)
139 138     rngs = nnx.Rngs(key)
140 139     pyconfig.initialize(
@@ -226,24 +225,25 @@ def test_wan_attention(self):
226 225     mesh = Mesh(devices_array, config.mesh_axes)
227 226     batch_size = 1
228 227     query_dim = 5120
229     -   with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
230     -     attention = FlaxWanAttention(
231     -         rngs=rngs,
232     -         query_dim=query_dim,
233     -         heads=40,
234     -         dim_head=128,
235     -         attention_kernel="flash",
236     -         mesh=mesh,
237     -         flash_block_sizes=flash_block_sizes,
238     -     )
239     -   dummy_hidden_states_shape = (batch_size, 75600, query_dim)
    228 +   for attention_kernel in ["flash", "tokamax_flash"]:
    229 +     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    230 +       attention = FlaxWanAttention(
    231 +           rngs=rngs,
    232 +           query_dim=query_dim,
    233 +           heads=40,
    234 +           dim_head=128,
    235 +           attention_kernel=attention_kernel,
    236 +           mesh=mesh,
    237 +           flash_block_sizes=flash_block_sizes,
    238 +       )
    239 +     dummy_hidden_states_shape = (batch_size, 75600, query_dim)
240 240
241     -   dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
242     -   dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
243     -   dummy_output = attention(
244     -       hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
245     -   )
246     -   assert dummy_output.shape == dummy_hidden_states_shape
    241 +     dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
    242 +     dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
    243 +     dummy_output = attention(
    244 +         hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
    245 +     )
    246 +     assert dummy_output.shape == dummy_hidden_states_shape
247 247
248 248     # dot product
249 249     try: