Skip to content

Commit 246d7a5

Browse files
committed
test fix
1 parent 7861e25 commit 246d7a5

1 file changed

Lines changed: 36 additions & 21 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ class WanTransformerTest(unittest.TestCase):
4848

4949
def setUp(self):
5050
WanTransformerTest.dummy_data = {}
51+
pyconfig.initialize(
52+
[
53+
None,
54+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
55+
],
56+
unittest=True,
57+
)
58+
config = pyconfig.config
59+
devices_array = create_device_mesh(config)
60+
self.mesh = Mesh(devices_array, config.mesh_axes)
61+
5162

5263
def test_rotary_pos_embed(self):
5364
batch_size = 1
@@ -65,28 +76,31 @@ def test_nnx_pixart_alpha_text_projection(self):
6576
key = jax.random.key(0)
6677
rngs = nnx.Rngs(key)
6778
dummy_caption = jnp.ones((1, 512, 4096))
68-
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
69-
dummy_output = layer(dummy_caption)
70-
dummy_output.shape == (1, 512, 5120)
79+
with self.mesh:
80+
layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
81+
dummy_output = layer(dummy_caption)
82+
dummy_output.shape == (1, 512, 5120)
7183

7284
def test_nnx_timestep_embedding(self):
7385
key = jax.random.key(0)
7486
rngs = nnx.Rngs(key)
7587

7688
dummy_sample = jnp.ones((1, 256))
77-
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
78-
dummy_output = layer(dummy_sample)
79-
assert dummy_output.shape == (1, 5120)
89+
with self.mesh:
90+
layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
91+
dummy_output = layer(dummy_sample)
92+
assert dummy_output.shape == (1, 5120)
8093

8194
def test_fp32_layer_norm(self):
8295
key = jax.random.key(0)
8396
rngs = nnx.Rngs(key)
8497
batch_size = 1
8598
dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
8699
# expected same output shape with same dtype
87-
layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
88-
dummy_output = layer(dummy_hidden_states)
89-
assert dummy_output.shape == dummy_hidden_states.shape
100+
with self.mesh:
101+
layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
102+
dummy_output = layer(dummy_hidden_states)
103+
assert dummy_output.shape == dummy_hidden_states.shape
90104

91105
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
92106
def test_wan_time_text_embedding(self):
@@ -97,20 +111,21 @@ def test_wan_time_text_embedding(self):
97111
time_freq_dim = 256
98112
time_proj_dim = 30720
99113
text_embed_dim = 4096
100-
layer = WanTimeTextImageEmbedding(
101-
rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
102-
)
114+
with self.mesh:
115+
layer = WanTimeTextImageEmbedding(
116+
rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
117+
)
103118

104-
dummy_timestep = jnp.ones(batch_size)
119+
dummy_timestep = jnp.ones(batch_size)
105120

106-
encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
107-
dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
108-
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
109-
dummy_timestep, dummy_encoder_hidden_states
110-
)
111-
assert temb.shape == (batch_size, dim)
112-
assert timestep_proj.shape == (batch_size, time_proj_dim)
113-
assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
121+
encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
122+
dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
123+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
124+
dummy_timestep, dummy_encoder_hidden_states
125+
)
126+
assert temb.shape == (batch_size, dim)
127+
assert timestep_proj.shape == (batch_size, time_proj_dim)
128+
assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
114129

115130
def test_wan_block(self):
116131
key = jax.random.key(0)

0 commit comments

Comments
 (0)