 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -48,6 +48,19 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
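+    # Load the shared WAN 14B test config and build the device mesh once, so the tests below can reuse self.config and self.mesh.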
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -65,28 +78,32 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
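+    # Construct the layer inside the mesh and logical axis-rules context so its partitioned parameters can resolve their logical axes.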
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      assert dummy_output.shape == (1, 512, 5120)
 
   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)
 
   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -97,20 +114,22 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
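+    # As above, build and run the time/text embedding under the mesh and axis rules created in setUp.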
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
 
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -158,19 +177,20 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
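+    # Only the axis rules are needed while constructing the block; the mesh itself is entered just below for the forward pass.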
+    with nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
     with mesh:
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape
@@ -204,40 +224,40 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-    assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
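+    # Both construction and the flash-attention forward pass now run with the mesh and logical axis rules active.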
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product: building with this kernel may raise NotImplementedError, which the test tolerates
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -267,7 +287,9 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
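+    # The full model is likewise built under the axis rules so its partitioned parameters can be resolved.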
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))