@@ -163,8 +163,8 @@ def test_wan_block(self):
163163 mesh = mesh ,
164164 flash_block_sizes = flash_block_sizes ,
165165 )
166-
167- dummy_output = wan_block (dummy_hidden_states , dummy_encoder_hidden_states , dummy_temb , dummy_rotary_emb )
166+ with mesh :
167+ dummy_output = wan_block (dummy_hidden_states , dummy_encoder_hidden_states , dummy_temb , dummy_rotary_emb )
168168 assert dummy_output .shape == dummy_hidden_states .shape
169169
170170 def test_wan_attention (self ):
@@ -210,10 +210,10 @@ def test_wan_attention(self):
210210
211211 dummy_hidden_states = jnp .ones (dummy_hidden_states_shape )
212212 dummy_encoder_hidden_states = jnp .ones (dummy_hidden_states_shape )
213-
214- dummy_output = attention (
215- hidden_states = dummy_hidden_states , encoder_hidden_states = dummy_encoder_hidden_states , rotary_emb = dummy_rotary_emb
216- )
213+ with mesh :
214+ dummy_output = attention (
215+ hidden_states = dummy_hidden_states , encoder_hidden_states = dummy_encoder_hidden_states , rotary_emb = dummy_rotary_emb
216+ )
217217 assert dummy_output .shape == dummy_hidden_states_shape
218218
219219 # dot product
@@ -246,7 +246,7 @@ def test_wan_model(self):
246246 frames = 21
247247 height = 90
248248 width = 160
249- hidden_states_shape = (batch_size , frames , height , width , channels )
249+ hidden_states_shape = (batch_size , channels , frames , height , width )
250250 dummy_hidden_states = jnp .ones (hidden_states_shape )
251251
252252 key = jax .random .key (0 )
@@ -266,10 +266,14 @@ def test_wan_model(self):
266266
267267 dummy_timestep = jnp .ones ((batch_size ))
268268 dummy_encoder_hidden_states = jnp .ones ((batch_size , 512 , 4096 ))
269-
270- dummy_output = wan_model (
271- hidden_states = dummy_hidden_states , timestep = dummy_timestep , encoder_hidden_states = dummy_encoder_hidden_states
272- )
269+ with mesh :
270+ dummy_output = wan_model (
271+ hidden_states = dummy_hidden_states ,
272+ timestep = dummy_timestep ,
273+ encoder_hidden_states = dummy_encoder_hidden_states ,
274+ is_uncond = jnp .array (True , dtype = jnp .bool_ ),
275+ slg_mask = jnp .zeros (40 , dtype = jnp .bool_ )
276+ )
273277 assert dummy_output .shape == hidden_states_shape
274278
275279
0 commit comments