Skip to content

Commit 6395edc

Browse files
committed
modified wanpipeline encode decode calls
1 parent f49f5d9 commit 6395edc

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def prepare_latents_i2v_base(
529529
video_condition = video_condition.astype(vae_dtype)
530530

531531
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
532-
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
532+
encoded_output = self.vae.encode(video_condition)[0].mode()
533533

534534
# Normalize latents
535535
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
@@ -550,7 +550,7 @@ def _denormalize_latents(self, latents: jax.Array) -> jax.Array:
550550
def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray:
551551
"""Decodes latents to video frames and postprocesses."""
552552
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
553-
video = self.vae.decode(latents, self.vae_cache)[0]
553+
video = self.vae.decode(latents)[0]
554554

555555
video = jnp.transpose(video, (0, 4, 1, 2, 3))
556556
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)

0 commit comments

Comments
 (0)