 from unittest.mock import Mock, patch, call
 from absl.testing import absltest
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh
 
 from .. import pyconfig
@@ -163,43 +164,41 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape
 
   def test_wan_attention(self):
-    for attention_kernel in ["flash", "tokamax_flash"]:
-      pyconfig.initialize(
-          [
-              None,
-              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-              f"attention={attention_kernel}"
-          ],
-          unittest=True
-      )
-      config = pyconfig.config
-      batch_size = 1
-      channels = 16
-      frames = 21
-      height = 90
-      width = 160
-      hidden_states_shape = (batch_size, frames, height, width, channels)
-      dummy_hidden_states = jnp.ones(hidden_states_shape)
-      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
 
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
@@ -210,40 +209,39 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-      assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
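
Both tests now construct the nnx modules inside the mesh and logical-axis-rule contexts instead of only calling them there. A minimal standalone sketch of that pattern, assuming a single-axis mesh over the available devices and a hypothetical one-entry rule set standing in for the config's `logical_axis_rules` (the tests read theirs from `base_wan_14b.yml`):

```python
# Sketch only, not the repository's test code.
import jax
import numpy as np
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

devices_array = np.array(jax.devices())  # 1-D array of local devices
mesh = Mesh(devices_array, ("data",))  # single mesh axis named "data"
logical_axis_rules = (("activation_batch", "data"),)  # hypothetical rule set

# Module construction and the forward call both run under the two contexts,
# so logical-axis sharding annotations inside the module can be resolved to
# mesh axes at parameter-creation time as well as at call time.
with mesh, nn_partitioning.axis_rules(logical_axis_rules):
  pass  # build WanTransformerBlock / FlaxWanAttention here and call it
```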