Skip to content

Commit 251e61f

Browse files
committed
convert images from pytorch tensors to JAX arrays
1 parent 6c0b551 commit 251e61f

2 files changed

Lines changed: 20 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ def prepare_latents(
8686
latents: Optional[jax.Array] = None,
8787
last_image: Optional[jax.Array] = None,
8888
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
89+
90+
if hasattr(image, "detach"):
91+
image = image.detach().cpu().numpy()
92+
image = jnp.array(image)
93+
94+
if last_image is not None:
95+
if hasattr(last_image, "detach"):
96+
last_image = last_image.detach().cpu().numpy()
97+
last_image = jnp.array(last_image)
98+
8999
num_channels_latents = self.vae.z_dim
90100
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
91101
latent_height = height // self.vae_scale_factor_spatial

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ def prepare_latents(
8585
latents: Optional[jax.Array] = None,
8686
last_image: Optional[jax.Array] = None,
8787
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
88+
89+
if hasattr(image, "detach"):
90+
image = image.detach().cpu().numpy()
91+
image = jnp.array(image)
92+
93+
if last_image is not None:
94+
if hasattr(last_image, "detach"):
95+
last_image = last_image.detach().cpu().numpy()
96+
last_image = jnp.array(last_image)
97+
8898
num_channels_latents = self.vae.z_dim
8999
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
90100
latent_height = height // self.vae_scale_factor_spatial

0 commit comments

Comments
 (0)