3939from maxdiffusion .pipelines .wan .wan_pipeline import WanPipeline
4040import qwix
4141import numpy as np
42+ from flax .linen import partitioning as nn_partitioning
4243
4344RealQtRule = qwix .QtRule
4445
@@ -52,6 +53,17 @@ class WanTransformerTest(unittest.TestCase):
5253
def setUp(self):
  """Prepare the per-test fixtures: shared data, config and device mesh.

  Resets the class-level ``dummy_data`` dict, initializes ``pyconfig`` from
  ``base_wan_14b.yml`` with ``unittest=True``, and stores the resulting
  config on ``self.config`` and a ``Mesh`` built from it on ``self.mesh``.
  """
  WanTransformerTest.dummy_data = {}

  # First list entry is a None placeholder; the second is the config file path.
  wan_config_path = os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml")
  pyconfig.initialize([None, wan_config_path], unittest=True)
  self.config = pyconfig.config

  # The mesh is created from the config's device layout and its axis names.
  self.mesh = Mesh(create_device_mesh(self.config), self.config.mesh_axes)
66+
5567
5668 def test_rotary_pos_embed (self ):
5769 batch_size = 1
def test_nnx_pixart_alpha_text_projection(self):
  """Check the output shape of NNXPixArtAlphaTextProjection.

  Feeds a (1, 512, 4096) caption tensor of ones through a layer built with
  ``in_features=4096`` and ``hidden_size=5120`` and expects a
  (1, 512, 5120) result.
  """
  key = jax.random.key(0)
  rngs = nnx.Rngs(key)
  dummy_caption = jnp.ones((1, 512, 4096))

  # Build and call the layer with the shared mesh and logical axis rules active.
  with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
    dummy_output = layer(dummy_caption)

  # Must be an assert: a bare comparison is evaluated and discarded, so the
  # test would pass regardless of the output shape.
  assert dummy_output.shape == (1, 512, 5120)
8089
def test_nnx_timestep_embedding(self):
  """Check that NNXTimestepEmbedding turns a (1, 256) input into a (1, 5120) output."""
  rngs = nnx.Rngs(jax.random.key(0))
  sample = jnp.ones((1, 256))

  # Build and call the layer with the shared mesh and logical axis rules active.
  with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
    embedding = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
    embedded = embedding(sample)
    assert embedded.shape == (1, 5120)
@@ -110,9 +116,10 @@ def test_wan_time_text_embedding(self):
110116 time_freq_dim = 256
111117 time_proj_dim = 30720
112118 text_embed_dim = 4096
113- layer = WanTimeTextImageEmbedding (
114- rngs = rngs , dim = dim , time_freq_dim = time_freq_dim , time_proj_dim = time_proj_dim , text_embed_dim = text_embed_dim
115- )
119+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
120+ layer = WanTimeTextImageEmbedding (
121+ rngs = rngs , dim = dim , time_freq_dim = time_freq_dim , time_proj_dim = time_proj_dim , text_embed_dim = text_embed_dim
122+ )
116123
117124 dummy_timestep = jnp .ones (batch_size )
118125
@@ -128,20 +135,8 @@ def test_wan_time_text_embedding(self):
128135 def test_wan_block (self ):
129136 key = jax .random .key (0 )
130137 rngs = nnx .Rngs (key )
131- pyconfig .initialize (
132- [
133- None ,
134- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
135- ],
136- unittest = True ,
137- )
138- config = pyconfig .config
139-
140- devices_array = create_device_mesh (config )
141-
142- flash_block_sizes = get_flash_block_sizes (config )
143138
144- mesh = Mesh ( devices_array , config . mesh_axes )
139+ flash_block_sizes = get_flash_block_sizes ( self . config )
145140
146141 dim = 5120
147142 ffn_dim = 13824
@@ -171,33 +166,24 @@ def test_wan_block(self):
171166 dummy_encoder_hidden_states = jnp .ones ((batch_size , 512 , dim ))
172167
173168 dummy_temb = jnp .ones ((batch_size , 6 , dim ))
174-
175- wan_block = WanTransformerBlock (
176- rngs = rngs ,
177- dim = dim ,
178- ffn_dim = ffn_dim ,
179- num_heads = num_heads ,
180- qk_norm = qk_norm ,
181- cross_attn_norm = cross_attn_norm ,
182- eps = eps ,
183- attention = "flash" ,
184- mesh = mesh ,
185- flash_block_sizes = flash_block_sizes ,
186- )
187- with mesh :
169+ with self . mesh , nn_partitioning . axis_rules ( self . config . logical_axis_rules ):
170+ wan_block = WanTransformerBlock (
171+ rngs = rngs ,
172+ dim = dim ,
173+ ffn_dim = ffn_dim ,
174+ num_heads = num_heads ,
175+ qk_norm = qk_norm ,
176+ cross_attn_norm = cross_attn_norm ,
177+ eps = eps ,
178+ attention = "flash" ,
179+ mesh = self . mesh ,
180+ flash_block_sizes = flash_block_sizes ,
181+ )
182+ with self . mesh , nn_partitioning . axis_rules ( self . config . logical_axis_rules ) :
188183 dummy_output = wan_block (dummy_hidden_states , dummy_encoder_hidden_states , dummy_temb , dummy_rotary_emb )
189184 assert dummy_output .shape == dummy_hidden_states .shape
190185
191186 def test_wan_attention (self ):
192- pyconfig .initialize (
193- [
194- None ,
195- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
196- ],
197- unittest = True ,
198- )
199- config = pyconfig .config
200-
201187 batch_size = 1
202188 channels = 16
203189 frames = 21
@@ -210,59 +196,49 @@ def test_wan_attention(self):
210196
211197 key = jax .random .key (0 )
212198 rngs = nnx .Rngs (key )
213- devices_array = create_device_mesh (config )
214-
215- flash_block_sizes = get_flash_block_sizes (config )
199+ flash_block_sizes = get_flash_block_sizes (self .config )
216200
217- mesh = Mesh (devices_array , config .mesh_axes )
218201 batch_size = 1
219202 query_dim = 5120
220- attention = FlaxWanAttention (
221- rngs = rngs ,
222- query_dim = query_dim ,
223- heads = 40 ,
224- dim_head = 128 ,
225- attention_kernel = "flash" ,
226- mesh = mesh ,
227- flash_block_sizes = flash_block_sizes ,
228- )
203+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
204+ attention = FlaxWanAttention (
205+ rngs = rngs ,
206+ query_dim = query_dim ,
207+ heads = 40 ,
208+ dim_head = 128 ,
209+ attention_kernel = "flash" ,
210+ mesh = self .mesh ,
211+ flash_block_sizes = flash_block_sizes ,
212+ )
229213
230214 dummy_hidden_states_shape = (batch_size , 75600 , query_dim )
231215
232216 dummy_hidden_states = jnp .ones (dummy_hidden_states_shape )
233217 dummy_encoder_hidden_states = jnp .ones (dummy_hidden_states_shape )
234- with mesh :
218+ with self . mesh , nn_partitioning . axis_rules ( self . config . logical_axis_rules ) :
235219 dummy_output = attention (
236220 hidden_states = dummy_hidden_states , encoder_hidden_states = dummy_encoder_hidden_states , rotary_emb = dummy_rotary_emb
237221 )
238222 assert dummy_output .shape == dummy_hidden_states_shape
239223
240224 # dot product
241225 try :
242- attention = FlaxWanAttention (
243- rngs = rngs ,
244- query_dim = query_dim ,
245- heads = 40 ,
246- dim_head = 128 ,
247- attention_kernel = "dot_product" ,
248- split_head_dim = True ,
249- mesh = mesh ,
250- flash_block_sizes = flash_block_sizes ,
251- )
226+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
227+ attention = FlaxWanAttention (
228+ rngs = rngs ,
229+ query_dim = query_dim ,
230+ heads = 40 ,
231+ dim_head = 128 ,
232+ attention_kernel = "dot_product" ,
233+ split_head_dim = True ,
234+ mesh = self .mesh ,
235+ flash_block_sizes = flash_block_sizes ,
236+ )
252237 except NotImplementedError :
253238 pass
254239
255240 @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
256241 def test_wan_model (self ):
257- pyconfig .initialize (
258- [
259- None ,
260- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
261- ],
262- unittest = True ,
263- )
264- config = pyconfig .config
265-
266242 batch_size = 1
267243 channels = 16
268244 frames = 1
@@ -273,18 +249,16 @@ def test_wan_model(self):
273249
274250 key = jax .random .key (0 )
275251 rngs = nnx .Rngs (key )
276- devices_array = create_device_mesh (config )
277-
278- flash_block_sizes = get_flash_block_sizes (config )
279252
280- mesh = Mesh ( devices_array , config . mesh_axes )
253+ flash_block_sizes = get_flash_block_sizes ( self . config )
281254 batch_size = 1
282255 num_layers = 1
283- wan_model = WanModel (rngs = rngs , attention = "flash" , mesh = mesh , flash_block_sizes = flash_block_sizes , num_layers = num_layers )
256+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
257+ wan_model = WanModel (rngs = rngs , attention = "flash" , mesh = self .mesh , flash_block_sizes = flash_block_sizes , num_layers = num_layers )
284258
285259 dummy_timestep = jnp .ones ((batch_size ))
286260 dummy_encoder_hidden_states = jnp .ones ((batch_size , 512 , 4096 ))
287- with mesh :
261+ with self . mesh , nn_partitioning . axis_rules ( self . config . logical_axis_rules ) :
288262 dummy_output = wan_model (
289263 hidden_states = dummy_hidden_states , timestep = dummy_timestep , encoder_hidden_states = dummy_encoder_hidden_states
290264 )
0 commit comments