Skip to content

Commit e01925d

Browse files
committed
mesh axis names are different
1 parent da25a1b commit e01925d

1 file changed

Lines changed: 7 additions & 20 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from maxdiffusion.pyconfig import HyperParameters
3939
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
4040
import qwix
41+
import numpy as np
4142

4243
RealQtRule = qwix.QtRule
4344

@@ -68,16 +69,9 @@ def test_nnx_pixart_alpha_text_projection(self):
6869
key = jax.random.key(0)
6970
rngs = nnx.Rngs(key)
7071
dummy_caption = jnp.ones((1, 512, 4096))
71-
pyconfig.initialize(
72-
[
73-
None,
74-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
75-
],
76-
unittest=True,
77-
)
78-
config = pyconfig.config
79-
devices_array = create_device_mesh(config)
80-
mesh = Mesh(devices_array, config.mesh_axes)
72+
num_devices = len(jax.devices())
73+
device_mesh = np.array(jax.devices()).reshape((1, num_devices))
74+
mesh = Mesh(device_mesh, axis_names=('embed', 'mlp'))
8175

8276
with mesh:
8377
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
@@ -87,16 +81,9 @@ def test_nnx_pixart_alpha_text_projection(self):
8781
def test_nnx_timestep_embedding(self):
8882
key = jax.random.key(0)
8983
rngs = nnx.Rngs(key)
90-
pyconfig.initialize(
91-
[
92-
None,
93-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
94-
],
95-
unittest=True,
96-
)
97-
config = pyconfig.config
98-
devices_array = create_device_mesh(config)
99-
mesh = Mesh(devices_array, config.mesh_axes)
84+
num_devices = len(jax.devices())
85+
device_mesh = np.array(jax.devices()).reshape((1, num_devices))
86+
mesh = Mesh(device_mesh, axis_names=('embed', 'mlp'))
10087

10188
dummy_sample = jnp.ones((1, 256))
10289
with mesh:

0 commit comments

Comments
 (0)