@@ -181,25 +181,24 @@ def test_wan_block(self):
181181 assert dummy_output .shape == dummy_hidden_states .shape
182182
183183 def test_wan_attention (self ):
184- for attention_kernel in ["flash" , "tokamax_flash" ]:
185- pyconfig .initialize (
186- [
187- None ,
188- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
189- f"attention={ attention_kernel } "
190- ],
191- unittest = True
192- )
193- config = pyconfig .config
194- batch_size = 1
195- channels = 16
196- frames = 21
197- height = 90
198- width = 160
199- hidden_states_shape = (batch_size , frames , height , width , channels )
200- dummy_hidden_states = jnp .ones (hidden_states_shape )
201- wan_rot_embed = WanRotaryPosEmbed (attention_head_dim = 128 , patch_size = [1 , 2 , 2 ], max_seq_len = 1024 )
202- dummy_rotary_emb = wan_rot_embed (dummy_hidden_states )
184+ pyconfig .initialize (
185+ [
186+ None ,
187+ os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
188+ ],
189+ unittest = True ,
190+ )
191+ config = pyconfig .config
192+
193+ batch_size = 1
194+ channels = 16
195+ frames = 21
196+ height = 90
197+ width = 160
198+ hidden_states_shape = (batch_size , frames , height , width , channels )
199+ dummy_hidden_states = jnp .ones (hidden_states_shape )
200+ wan_rot_embed = WanRotaryPosEmbed (attention_head_dim = 128 , patch_size = [1 , 2 , 2 ], max_seq_len = 1024 )
201+ dummy_rotary_emb = wan_rot_embed (dummy_hidden_states )
203202
204203 key = jax .random .key (0 )
205204 rngs = nnx .Rngs (key )
@@ -425,4 +424,4 @@ def test_quantize_transformer_disabled(self, mock_quantize_model):
425424
426425
427426if __name__ == "__main__" :
428- absltest .main ()
427+ absltest .main ()
0 commit comments