Skip to content

Commit 049fe51

Browse files
committed
test fix
1 parent 7861e25 commit 049fe51

1 file changed

Lines changed: 78 additions & 62 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from absl.testing import absltest
2424
from flax import nnx
2525
from jax.sharding import Mesh
26-
26+
from flax.linen import partitioning as nn_partitioning
2727
from .. import pyconfig
2828
from ..max_utils import (create_device_mesh, get_flash_block_sizes)
2929
from ..models.wan.transformers.transformer_wan import (
@@ -48,6 +48,18 @@ class WanTransformerTest(unittest.TestCase):
4848

4949
def setUp(self):
5050
WanTransformerTest.dummy_data = {}
51+
pyconfig.initialize(
52+
[
53+
None,
54+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
55+
],
56+
unittest=True,
57+
)
58+
config = pyconfig.config
59+
self.config = config
60+
devices_array = create_device_mesh(config)
61+
self.mesh = Mesh(devices_array, config.mesh_axes)
62+
5163

5264
def test_rotary_pos_embed(self):
5365
batch_size = 1
@@ -65,28 +77,31 @@ def test_nnx_pixart_alpha_text_projection(self):
6577
key = jax.random.key(0)
6678
rngs = nnx.Rngs(key)
6779
dummy_caption = jnp.ones((1, 512, 4096))
68-
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
69-
dummy_output = layer(dummy_caption)
70-
dummy_output.shape == (1, 512, 5120)
80+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
81+
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
82+
dummy_output = layer(dummy_caption)
83+
dummy_output.shape == (1, 512, 5120)
7184

7285
def test_nnx_timestep_embedding(self):
7386
key = jax.random.key(0)
7487
rngs = nnx.Rngs(key)
7588

7689
dummy_sample = jnp.ones((1, 256))
77-
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
78-
dummy_output = layer(dummy_sample)
79-
assert dummy_output.shape == (1, 5120)
90+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
91+
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
92+
dummy_output = layer(dummy_sample)
93+
assert dummy_output.shape == (1, 5120)
8094

8195
def test_fp32_layer_norm(self):
8296
key = jax.random.key(0)
8397
rngs = nnx.Rngs(key)
8498
batch_size = 1
8599
dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
86100
# expected same output shape with same dtype
87-
layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
88-
dummy_output = layer(dummy_hidden_states)
89-
assert dummy_output.shape == dummy_hidden_states.shape
101+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
102+
layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
103+
dummy_output = layer(dummy_hidden_states)
104+
assert dummy_output.shape == dummy_hidden_states.shape
90105

91106
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
92107
def test_wan_time_text_embedding(self):
@@ -97,20 +112,21 @@ def test_wan_time_text_embedding(self):
97112
time_freq_dim = 256
98113
time_proj_dim = 30720
99114
text_embed_dim = 4096
100-
layer = WanTimeTextImageEmbedding(
101-
rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
102-
)
115+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
116+
layer = WanTimeTextImageEmbedding(
117+
rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
118+
)
103119

104-
dummy_timestep = jnp.ones(batch_size)
120+
dummy_timestep = jnp.ones(batch_size)
105121

106-
encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
107-
dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
108-
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
109-
dummy_timestep, dummy_encoder_hidden_states
110-
)
111-
assert temb.shape == (batch_size, dim)
112-
assert timestep_proj.shape == (batch_size, time_proj_dim)
113-
assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
122+
encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
123+
dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
124+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
125+
dummy_timestep, dummy_encoder_hidden_states
126+
)
127+
assert temb.shape == (batch_size, dim)
128+
assert timestep_proj.shape == (batch_size, time_proj_dim)
129+
assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
114130

115131
def test_wan_block(self):
116132
key = jax.random.key(0)
@@ -158,19 +174,19 @@ def test_wan_block(self):
158174
dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
159175

160176
dummy_temb = jnp.ones((batch_size, 6, dim))
161-
162-
wan_block = WanTransformerBlock(
163-
rngs=rngs,
164-
dim=dim,
165-
ffn_dim=ffn_dim,
166-
num_heads=num_heads,
167-
qk_norm=qk_norm,
168-
cross_attn_norm=cross_attn_norm,
169-
eps=eps,
170-
attention="flash",
171-
mesh=mesh,
172-
flash_block_sizes=flash_block_sizes,
173-
)
177+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
178+
wan_block = WanTransformerBlock(
179+
rngs=rngs,
180+
dim=dim,
181+
ffn_dim=ffn_dim,
182+
num_heads=num_heads,
183+
qk_norm=qk_norm,
184+
cross_attn_norm=cross_attn_norm,
185+
eps=eps,
186+
attention="flash",
187+
mesh=mesh,
188+
flash_block_sizes=flash_block_sizes,
189+
)
174190
with mesh:
175191
dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
176192
assert dummy_output.shape == dummy_hidden_states.shape
@@ -204,40 +220,39 @@ def test_wan_attention(self):
204220
mesh = Mesh(devices_array, config.mesh_axes)
205221
batch_size = 1
206222
query_dim = 5120
207-
attention = FlaxWanAttention(
208-
rngs=rngs,
209-
query_dim=query_dim,
210-
heads=40,
211-
dim_head=128,
212-
attention_kernel="flash",
213-
mesh=mesh,
214-
flash_block_sizes=flash_block_sizes,
215-
)
216-
217-
dummy_hidden_states_shape = (batch_size, 75600, query_dim)
218-
219-
dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
220-
dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
221-
with mesh:
222-
dummy_output = attention(
223-
hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
224-
)
225-
assert dummy_output.shape == dummy_hidden_states_shape
226-
227-
# dot product
228-
try:
223+
with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
229224
attention = FlaxWanAttention(
230225
rngs=rngs,
231226
query_dim=query_dim,
232227
heads=40,
233228
dim_head=128,
234-
attention_kernel="dot_product",
235-
split_head_dim=True,
229+
attention_kernel="flash",
236230
mesh=mesh,
237231
flash_block_sizes=flash_block_sizes,
238232
)
239-
except NotImplementedError:
240-
pass
233+
dummy_hidden_states_shape = (batch_size, 75600, query_dim)
234+
235+
dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
236+
dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
237+
dummy_output = attention(
238+
hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
239+
)
240+
assert dummy_output.shape == dummy_hidden_states_shape
241+
242+
# dot product
243+
try:
244+
attention = FlaxWanAttention(
245+
rngs=rngs,
246+
query_dim=query_dim,
247+
heads=40,
248+
dim_head=128,
249+
attention_kernel="dot_product",
250+
split_head_dim=True,
251+
mesh=mesh,
252+
flash_block_sizes=flash_block_sizes,
253+
)
254+
except NotImplementedError:
255+
pass
241256

242257
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
243258
def test_wan_model(self):
@@ -267,7 +282,8 @@ def test_wan_model(self):
267282
mesh = Mesh(devices_array, config.mesh_axes)
268283
batch_size = 1
269284
num_layers = 1
270-
wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
285+
with nn_partitioning.axis_rules(config.logical_axis_rules):
286+
wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
271287

272288
dummy_timestep = jnp.ones((batch_size))
273289
dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))

0 commit comments

Comments
 (0)