2323from absl .testing import absltest
2424from flax import nnx
2525from jax .sharding import Mesh
26-
26+ from flax . linen import partitioning as nn_partitioning
2727from .. import pyconfig
2828from ..max_utils import (create_device_mesh , get_flash_block_sizes )
2929from ..models .wan .transformers .transformer_wan import (
@@ -48,6 +48,18 @@ class WanTransformerTest(unittest.TestCase):
4848
4949 def setUp (self ):
5050 WanTransformerTest .dummy_data = {}
51+ pyconfig .initialize (
52+ [
53+ None ,
54+ os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
55+ ],
56+ unittest = True ,
57+ )
58+ config = pyconfig .config
59+ self .config = config
60+ devices_array = create_device_mesh (config )
61+ self .mesh = Mesh (devices_array , config .mesh_axes )
62+
5163
5264 def test_rotary_pos_embed (self ):
5365 batch_size = 1
@@ -65,28 +77,31 @@ def test_nnx_pixart_alpha_text_projection(self):
6577 key = jax .random .key (0 )
6678 rngs = nnx .Rngs (key )
6779 dummy_caption = jnp .ones ((1 , 512 , 4096 ))
68- layer = NNXPixArtAlphaTextProjection (rngs = rngs , in_features = 4096 , hidden_size = 5120 )
69- dummy_output = layer (dummy_caption )
70- dummy_output .shape == (1 , 512 , 5120 )
80+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
81+ layer = NNXPixArtAlphaTextProjection (rngs = rngs , in_features = 4096 , hidden_size = 5120 )
82+ dummy_output = layer (dummy_caption )
83+ dummy_output .shape == (1 , 512 , 5120 )
7184
7285 def test_nnx_timestep_embedding (self ):
7386 key = jax .random .key (0 )
7487 rngs = nnx .Rngs (key )
7588
7689 dummy_sample = jnp .ones ((1 , 256 ))
77- layer = NNXTimestepEmbedding (rngs = rngs , in_channels = 256 , time_embed_dim = 5120 )
78- dummy_output = layer (dummy_sample )
79- assert dummy_output .shape == (1 , 5120 )
90+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
91+ layer = NNXTimestepEmbedding (rngs = rngs , in_channels = 256 , time_embed_dim = 5120 )
92+ dummy_output = layer (dummy_sample )
93+ assert dummy_output .shape == (1 , 5120 )
8094
8195 def test_fp32_layer_norm (self ):
8296 key = jax .random .key (0 )
8397 rngs = nnx .Rngs (key )
8498 batch_size = 1
8599 dummy_hidden_states = jnp .ones ((batch_size , 75600 , 5120 ))
86100 # expected same output shape with same dtype
87- layer = FP32LayerNorm (rngs = rngs , dim = 5120 , eps = 1e-6 , elementwise_affine = False )
88- dummy_output = layer (dummy_hidden_states )
89- assert dummy_output .shape == dummy_hidden_states .shape
101+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
102+ layer = FP32LayerNorm (rngs = rngs , dim = 5120 , eps = 1e-6 , elementwise_affine = False )
103+ dummy_output = layer (dummy_hidden_states )
104+ assert dummy_output .shape == dummy_hidden_states .shape
90105
91106 @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
92107 def test_wan_time_text_embedding (self ):
@@ -97,20 +112,21 @@ def test_wan_time_text_embedding(self):
97112 time_freq_dim = 256
98113 time_proj_dim = 30720
99114 text_embed_dim = 4096
100- layer = WanTimeTextImageEmbedding (
101- rngs = rngs , dim = dim , time_freq_dim = time_freq_dim , time_proj_dim = time_proj_dim , text_embed_dim = text_embed_dim
102- )
115+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
116+ layer = WanTimeTextImageEmbedding (
117+ rngs = rngs , dim = dim , time_freq_dim = time_freq_dim , time_proj_dim = time_proj_dim , text_embed_dim = text_embed_dim
118+ )
103119
104- dummy_timestep = jnp .ones (batch_size )
120+ dummy_timestep = jnp .ones (batch_size )
105121
106- encoder_hidden_states_shape = (batch_size , time_freq_dim * 2 , text_embed_dim )
107- dummy_encoder_hidden_states = jnp .ones (encoder_hidden_states_shape )
108- temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image = layer (
109- dummy_timestep , dummy_encoder_hidden_states
110- )
111- assert temb .shape == (batch_size , dim )
112- assert timestep_proj .shape == (batch_size , time_proj_dim )
113- assert encoder_hidden_states .shape == (batch_size , time_freq_dim * 2 , dim )
122+ encoder_hidden_states_shape = (batch_size , time_freq_dim * 2 , text_embed_dim )
123+ dummy_encoder_hidden_states = jnp .ones (encoder_hidden_states_shape )
124+ temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image = layer (
125+ dummy_timestep , dummy_encoder_hidden_states
126+ )
127+ assert temb .shape == (batch_size , dim )
128+ assert timestep_proj .shape == (batch_size , time_proj_dim )
129+ assert encoder_hidden_states .shape == (batch_size , time_freq_dim * 2 , dim )
114130
115131 def test_wan_block (self ):
116132 key = jax .random .key (0 )
@@ -171,7 +187,7 @@ def test_wan_block(self):
171187 mesh = mesh ,
172188 flash_block_sizes = flash_block_sizes ,
173189 )
174- with mesh :
190+ with mesh , nn_partitioning . axis_rules ( self . config . logical_axis_rules ) :
175191 dummy_output = wan_block (dummy_hidden_states , dummy_encoder_hidden_states , dummy_temb , dummy_rotary_emb )
176192 assert dummy_output .shape == dummy_hidden_states .shape
177193
@@ -218,7 +234,7 @@ def test_wan_attention(self):
218234
219235 dummy_hidden_states = jnp .ones (dummy_hidden_states_shape )
220236 dummy_encoder_hidden_states = jnp .ones (dummy_hidden_states_shape )
221- with mesh :
237+ with mesh , nn_partitioning . axis_rules ( self . config . logical_axis_rules ) :
222238 dummy_output = attention (
223239 hidden_states = dummy_hidden_states , encoder_hidden_states = dummy_encoder_hidden_states , rotary_emb = dummy_rotary_emb
224240 )
0 commit comments