Skip to content

Commit d473a33

Browse files
committed
add mesh
1 parent e01925d commit d473a33

1 file changed

Lines changed: 60 additions & 86 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 60 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
4040
import qwix
4141
import numpy as np
42+
from flax.linen import partitioning as nn_partitioning
4243

4344
RealQtRule = qwix.QtRule
4445

@@ -52,6 +53,17 @@ class WanTransformerTest(unittest.TestCase):
5253

5354
def setUp(self):
5455
WanTransformerTest.dummy_data = {}
56+
pyconfig.initialize(
57+
[
58+
None,
59+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
60+
],
61+
unittest=True,
62+
)
63+
self.config = pyconfig.config
64+
devices_array = create_device_mesh(self.config)
65+
self.mesh = Mesh(devices_array, self.config.mesh_axes)
66+
5567

5668
def test_rotary_pos_embed(self):
5769
batch_size = 1
@@ -69,24 +81,18 @@ def test_nnx_pixart_alpha_text_projection(self):
6981
key = jax.random.key(0)
7082
rngs = nnx.Rngs(key)
7183
dummy_caption = jnp.ones((1, 512, 4096))
72-
num_devices = len(jax.devices())
73-
device_mesh = np.array(jax.devices()).reshape((1, num_devices))
74-
mesh = Mesh(device_mesh, axis_names=('embed', 'mlp'))
7584

76-
with mesh:
85+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
7786
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
7887
dummy_output = layer(dummy_caption)
7988
dummy_output.shape == (1, 512, 5120)
8089

8190
def test_nnx_timestep_embedding(self):
8291
key = jax.random.key(0)
8392
rngs = nnx.Rngs(key)
84-
num_devices = len(jax.devices())
85-
device_mesh = np.array(jax.devices()).reshape((1, num_devices))
86-
mesh = Mesh(device_mesh, axis_names=('embed', 'mlp'))
8793

8894
dummy_sample = jnp.ones((1, 256))
89-
with mesh:
95+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
9096
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
9197
dummy_output = layer(dummy_sample)
9298
assert dummy_output.shape == (1, 5120)
@@ -110,9 +116,10 @@ def test_wan_time_text_embedding(self):
110116
time_freq_dim = 256
111117
time_proj_dim = 30720
112118
text_embed_dim = 4096
113-
layer = WanTimeTextImageEmbedding(
114-
rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
115-
)
119+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
120+
layer = WanTimeTextImageEmbedding(
121+
rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
122+
)
116123

117124
dummy_timestep = jnp.ones(batch_size)
118125

@@ -128,20 +135,8 @@ def test_wan_time_text_embedding(self):
128135
def test_wan_block(self):
129136
key = jax.random.key(0)
130137
rngs = nnx.Rngs(key)
131-
pyconfig.initialize(
132-
[
133-
None,
134-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
135-
],
136-
unittest=True,
137-
)
138-
config = pyconfig.config
139-
140-
devices_array = create_device_mesh(config)
141-
142-
flash_block_sizes = get_flash_block_sizes(config)
143138

144-
mesh = Mesh(devices_array, config.mesh_axes)
139+
flash_block_sizes = get_flash_block_sizes(self.config)
145140

146141
dim = 5120
147142
ffn_dim = 13824
@@ -171,33 +166,24 @@ def test_wan_block(self):
171166
dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
172167

173168
dummy_temb = jnp.ones((batch_size, 6, dim))
174-
175-
wan_block = WanTransformerBlock(
176-
rngs=rngs,
177-
dim=dim,
178-
ffn_dim=ffn_dim,
179-
num_heads=num_heads,
180-
qk_norm=qk_norm,
181-
cross_attn_norm=cross_attn_norm,
182-
eps=eps,
183-
attention="flash",
184-
mesh=mesh,
185-
flash_block_sizes=flash_block_sizes,
186-
)
187-
with mesh:
169+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
170+
wan_block = WanTransformerBlock(
171+
rngs=rngs,
172+
dim=dim,
173+
ffn_dim=ffn_dim,
174+
num_heads=num_heads,
175+
qk_norm=qk_norm,
176+
cross_attn_norm=cross_attn_norm,
177+
eps=eps,
178+
attention="flash",
179+
mesh=self.mesh,
180+
flash_block_sizes=flash_block_sizes,
181+
)
182+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
188183
dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
189184
assert dummy_output.shape == dummy_hidden_states.shape
190185

191186
def test_wan_attention(self):
192-
pyconfig.initialize(
193-
[
194-
None,
195-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
196-
],
197-
unittest=True,
198-
)
199-
config = pyconfig.config
200-
201187
batch_size = 1
202188
channels = 16
203189
frames = 21
@@ -210,59 +196,49 @@ def test_wan_attention(self):
210196

211197
key = jax.random.key(0)
212198
rngs = nnx.Rngs(key)
213-
devices_array = create_device_mesh(config)
214-
215-
flash_block_sizes = get_flash_block_sizes(config)
199+
flash_block_sizes = get_flash_block_sizes(self.config)
216200

217-
mesh = Mesh(devices_array, config.mesh_axes)
218201
batch_size = 1
219202
query_dim = 5120
220-
attention = FlaxWanAttention(
221-
rngs=rngs,
222-
query_dim=query_dim,
223-
heads=40,
224-
dim_head=128,
225-
attention_kernel="flash",
226-
mesh=mesh,
227-
flash_block_sizes=flash_block_sizes,
228-
)
203+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
204+
attention = FlaxWanAttention(
205+
rngs=rngs,
206+
query_dim=query_dim,
207+
heads=40,
208+
dim_head=128,
209+
attention_kernel="flash",
210+
mesh=self.mesh,
211+
flash_block_sizes=flash_block_sizes,
212+
)
229213

230214
dummy_hidden_states_shape = (batch_size, 75600, query_dim)
231215

232216
dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
233217
dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
234-
with mesh:
218+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
235219
dummy_output = attention(
236220
hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
237221
)
238222
assert dummy_output.shape == dummy_hidden_states_shape
239223

240224
# dot product
241225
try:
242-
attention = FlaxWanAttention(
243-
rngs=rngs,
244-
query_dim=query_dim,
245-
heads=40,
246-
dim_head=128,
247-
attention_kernel="dot_product",
248-
split_head_dim=True,
249-
mesh=mesh,
250-
flash_block_sizes=flash_block_sizes,
251-
)
226+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
227+
attention = FlaxWanAttention(
228+
rngs=rngs,
229+
query_dim=query_dim,
230+
heads=40,
231+
dim_head=128,
232+
attention_kernel="dot_product",
233+
split_head_dim=True,
234+
mesh=self.mesh,
235+
flash_block_sizes=flash_block_sizes,
236+
)
252237
except NotImplementedError:
253238
pass
254239

255240
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
256241
def test_wan_model(self):
257-
pyconfig.initialize(
258-
[
259-
None,
260-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
261-
],
262-
unittest=True,
263-
)
264-
config = pyconfig.config
265-
266242
batch_size = 1
267243
channels = 16
268244
frames = 1
@@ -273,18 +249,16 @@ def test_wan_model(self):
273249

274250
key = jax.random.key(0)
275251
rngs = nnx.Rngs(key)
276-
devices_array = create_device_mesh(config)
277-
278-
flash_block_sizes = get_flash_block_sizes(config)
279252

280-
mesh = Mesh(devices_array, config.mesh_axes)
253+
flash_block_sizes = get_flash_block_sizes(self.config)
281254
batch_size = 1
282255
num_layers = 1
283-
wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
256+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
257+
wan_model = WanModel(rngs=rngs, attention="flash", mesh=self.mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
284258

285259
dummy_timestep = jnp.ones((batch_size))
286260
dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
287-
with mesh:
261+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
288262
dummy_output = wan_model(
289263
hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
290264
)

0 commit comments

Comments
 (0)