Skip to content

Commit 662e14a

Browse files
committed
vae.config access changed
1 parent 83bcf33 commit 662e14a

2 files changed

Lines changed: 2 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from flax.linen import partitioning as nn_partitioning
2323
import jax
2424
import jax.numpy as jnp
25-
import numpy as np
2625
from jax.sharding import NamedSharding, PartitionSpec as P
2726
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2827
from ...max_utils import randn_tensor
@@ -87,7 +86,7 @@ def prepare_latents(
8786
latents: Optional[jax.Array] = None,
8887
last_image: Optional[jax.Array] = None,
8988
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
90-
num_channels_latents = self.vae.config['z_dim']
89+
num_channels_latents = self.vae.z_dim
9190
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
9291
latent_height = height // self.vae_scale_factor_spatial
9392
latent_width = width // self.vae_scale_factor_spatial

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from flax.linen import partitioning as nn_partitioning
2323
import jax
2424
import jax.numpy as jnp
25-
import numpy as np
2625
from jax.sharding import NamedSharding, PartitionSpec as P
2726
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2827
from ...max_utils import randn_tensor
@@ -86,7 +85,7 @@ def prepare_latents(
8685
latents: Optional[jax.Array] = None,
8786
last_image: Optional[jax.Array] = None,
8887
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
89-
num_channels_latents = self.vae.config['z_dim']
88+
num_channels_latents = self.vae.z_dim
9089
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
9190
latent_height = height // self.vae_scale_factor_spatial
9291
latent_width = width // self.vae_scale_factor_spatial

0 commit comments

Comments
 (0)