Skip to content

Commit d462775

Browse files
committed
fix unit test
1 parent 01712be commit d462775

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
2-
jax==0.7.0
2+
jax>=0.6.2
33
jaxlib>=0.4.30
44
grain
55
google-cloud-storage>=2.17.0

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from maxdiffusion.pyconfig import HyperParameters
3939
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
4040
import qwix
41+
import numpy as np
4142

4243
RealQtRule = qwix.QtRule
4344

@@ -68,7 +69,11 @@ def test_nnx_pixart_alpha_text_projection(self):
6869
key = jax.random.key(0)
6970
rngs = nnx.Rngs(key)
7071
dummy_caption = jnp.ones((1, 512, 4096))
71-
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
72+
num_devices = len(jax.devices())
73+
device_mesh = np.array(jax.devices()).reshape((1, num_devices))
74+
mesh = Mesh(device_mesh, axis_names=('embed', 'mlp'))
75+
with mesh:
76+
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
7277
dummy_output = layer(dummy_caption)
7378
dummy_output.shape == (1, 512, 5120)
7479

0 commit comments

Comments
 (0)