Skip to content

Commit da25a1b

Browse files
committed
fix
1 parent d462775 commit da25a1b

1 file changed

Lines changed: 23 additions & 5 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from maxdiffusion.pyconfig import HyperParameters
3939
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
4040
import qwix
41-
import numpy as np
4241

4342
RealQtRule = qwix.QtRule
4443

@@ -69,9 +68,17 @@ def test_nnx_pixart_alpha_text_projection(self):
6968
key = jax.random.key(0)
7069
rngs = nnx.Rngs(key)
7170
dummy_caption = jnp.ones((1, 512, 4096))
72-
num_devices = len(jax.devices())
73-
device_mesh = np.array(jax.devices()).reshape((1, num_devices))
74-
mesh = Mesh(device_mesh, axis_names=('embed', 'mlp'))
71+
pyconfig.initialize(
72+
[
73+
None,
74+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
75+
],
76+
unittest=True,
77+
)
78+
config = pyconfig.config
79+
devices_array = create_device_mesh(config)
80+
mesh = Mesh(devices_array, config.mesh_axes)
81+
7582
with mesh:
7683
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
7784
dummy_output = layer(dummy_caption)
@@ -80,9 +87,20 @@ def test_nnx_pixart_alpha_text_projection(self):
8087
def test_nnx_timestep_embedding(self):
8188
key = jax.random.key(0)
8289
rngs = nnx.Rngs(key)
90+
pyconfig.initialize(
91+
[
92+
None,
93+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
94+
],
95+
unittest=True,
96+
)
97+
config = pyconfig.config
98+
devices_array = create_device_mesh(config)
99+
mesh = Mesh(devices_array, config.mesh_axes)
83100

84101
dummy_sample = jnp.ones((1, 256))
85-
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
102+
with mesh:
103+
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
86104
dummy_output = layer(dummy_sample)
87105
assert dummy_output.shape == (1, 5120)
88106

0 commit comments

Comments
 (0)