Skip to content

Commit d616736

Browse files
committed
Change in wan_vae_test.py
1 parent 3a1ccfd commit d616736

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,11 +522,11 @@ def vae_encode(video, wan_vae, vae_cache, key):
522522
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
523523
wan_vae = nnx.merge(graphdef, params)
524524

525-
p_vae_encode = jax.jit(functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key))
525+
p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)
526526
original_video_shape = original_video.shape
527527
latent = p_vae_encode(original_video)
528528

529-
jitted_decode = jax.jit(functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False))
529+
jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
530530
video = jitted_decode(latent)[0]
531531
video = jnp.transpose(video, (0, 4, 1, 2, 3))
532532
assert video.shape == original_video_shape

0 commit comments

Comments
 (0)