@@ -48,6 +48,17 @@ class WanTransformerTest(unittest.TestCase):
4848
4949 def setUp (self ):
5050 WanTransformerTest .dummy_data = {}
51+ pyconfig .initialize (
52+ [
53+ None ,
54+ os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
55+ ],
56+ unittest = True ,
57+ )
58+ config = pyconfig .config
59+ devices_array = create_device_mesh (config )
60+ self .mesh = Mesh (devices_array , config .mesh_axes )
61+
5162
5263 def test_rotary_pos_embed (self ):
5364 batch_size = 1
@@ -65,28 +76,31 @@ def test_nnx_pixart_alpha_text_projection(self):
6576 key = jax .random .key (0 )
6677 rngs = nnx .Rngs (key )
6778 dummy_caption = jnp .ones ((1 , 512 , 4096 ))
68- layer = NNXPixArtAlphaTextProjection (rngs = rngs , in_features = 4096 , hidden_size = 5120 )
69- dummy_output = layer (dummy_caption )
70- dummy_output .shape == (1 , 512 , 5120 )
79+ with self .mesh :
80+ layer = NNXPixArtAlphaTextProjection (rngs = rngs , in_features = 4096 , hidden_size = 5120 )
81+ dummy_output = layer (dummy_caption )
82+ dummy_output .shape == (1 , 512 , 5120 )
7183
7284 def test_nnx_timestep_embedding (self ):
7385 key = jax .random .key (0 )
7486 rngs = nnx .Rngs (key )
7587
7688 dummy_sample = jnp .ones ((1 , 256 ))
77- layer = NNXTimestepEmbedding (rngs = rngs , in_channels = 256 , time_embed_dim = 5120 )
78- dummy_output = layer (dummy_sample )
79- assert dummy_output .shape == (1 , 5120 )
89+ with self .mesh :
90+ layer = NNXTimestepEmbedding (rngs = rngs , in_channels = 256 , time_embed_dim = 5120 )
91+ dummy_output = layer (dummy_sample )
92+ assert dummy_output .shape == (1 , 5120 )
8093
8194 def test_fp32_layer_norm (self ):
8295 key = jax .random .key (0 )
8396 rngs = nnx .Rngs (key )
8497 batch_size = 1
8598 dummy_hidden_states = jnp .ones ((batch_size , 75600 , 5120 ))
8699 # expected same output shape with same dtype
87- layer = FP32LayerNorm (rngs = rngs , dim = 5120 , eps = 1e-6 , elementwise_affine = False )
88- dummy_output = layer (dummy_hidden_states )
89- assert dummy_output .shape == dummy_hidden_states .shape
100+ with self .mesh :
101+ layer = FP32LayerNorm (rngs = rngs , dim = 5120 , eps = 1e-6 , elementwise_affine = False )
102+ dummy_output = layer (dummy_hidden_states )
103+ assert dummy_output .shape == dummy_hidden_states .shape
90104
91105 @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
92106 def test_wan_time_text_embedding (self ):
@@ -97,20 +111,21 @@ def test_wan_time_text_embedding(self):
97111 time_freq_dim = 256
98112 time_proj_dim = 30720
99113 text_embed_dim = 4096
100- layer = WanTimeTextImageEmbedding (
101- rngs = rngs , dim = dim , time_freq_dim = time_freq_dim , time_proj_dim = time_proj_dim , text_embed_dim = text_embed_dim
102- )
114+ with self .mesh :
115+ layer = WanTimeTextImageEmbedding (
116+ rngs = rngs , dim = dim , time_freq_dim = time_freq_dim , time_proj_dim = time_proj_dim , text_embed_dim = text_embed_dim
117+ )
103118
104- dummy_timestep = jnp .ones (batch_size )
119+ dummy_timestep = jnp .ones (batch_size )
105120
106- encoder_hidden_states_shape = (batch_size , time_freq_dim * 2 , text_embed_dim )
107- dummy_encoder_hidden_states = jnp .ones (encoder_hidden_states_shape )
108- temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image = layer (
109- dummy_timestep , dummy_encoder_hidden_states
110- )
111- assert temb .shape == (batch_size , dim )
112- assert timestep_proj .shape == (batch_size , time_proj_dim )
113- assert encoder_hidden_states .shape == (batch_size , time_freq_dim * 2 , dim )
121+ encoder_hidden_states_shape = (batch_size , time_freq_dim * 2 , text_embed_dim )
122+ dummy_encoder_hidden_states = jnp .ones (encoder_hidden_states_shape )
123+ temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image = layer (
124+ dummy_timestep , dummy_encoder_hidden_states
125+ )
126+ assert temb .shape == (batch_size , dim )
127+ assert timestep_proj .shape == (batch_size , time_proj_dim )
128+ assert encoder_hidden_states .shape == (batch_size , time_freq_dim * 2 , dim )
114129
115130 def test_wan_block (self ):
116131 key = jax .random .key (0 )
0 commit comments