Skip to content

Commit bda990b

Browse files
committed
fixing vae tests
1 parent f46674b commit bda990b

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,13 @@ def vae_encode(video, wan_vae, vae_cache, key):
560560

561561
p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)
562562
original_video_shape = original_video.shape
563-
latent = p_vae_encode(original_video)
564563

565-
jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
566-
video = jitted_decode(latent)[0]
564+
# Execute VAE operations within mesh context to match sharding constraints
565+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
566+
latent = p_vae_encode(original_video)
567+
568+
jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
569+
video = jitted_decode(latent)[0]
567570
video = jnp.transpose(video, (0, 4, 1, 2, 3))
568571
assert video.shape == original_video_shape
569572

0 commit comments

Comments
 (0)