File tree Expand file tree Collapse file tree
src/maxdiffusion/pipelines/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2222from flax .linen import partitioning as nn_partitioning
2323import jax
2424import jax .numpy as jnp
25- import numpy as np
2625from jax .sharding import NamedSharding , PartitionSpec as P
2726from ...schedulers .scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2827from ...max_utils import randn_tensor
@@ -87,7 +86,7 @@ def prepare_latents(
8786 latents : Optional [jax .Array ] = None ,
8887 last_image : Optional [jax .Array ] = None ,
8988 ) -> Tuple [jax .Array , jax .Array , Optional [jax .Array ]]:
90- num_channels_latents = self .vae .config [ ' z_dim' ]
89+ num_channels_latents = self .vae .z_dim
9190 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
9291 latent_height = height // self .vae_scale_factor_spatial
9392 latent_width = width // self .vae_scale_factor_spatial
Original file line number Diff line number Diff line change 2222from flax .linen import partitioning as nn_partitioning
2323import jax
2424import jax .numpy as jnp
25- import numpy as np
2625from jax .sharding import NamedSharding , PartitionSpec as P
2726from ...schedulers .scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2827from ...max_utils import randn_tensor
@@ -86,7 +85,7 @@ def prepare_latents(
8685 latents : Optional [jax .Array ] = None ,
8786 last_image : Optional [jax .Array ] = None ,
8887) -> Tuple [jax .Array , jax .Array , Optional [jax .Array ]]:
89- num_channels_latents = self .vae .config [ ' z_dim' ]
88+ num_channels_latents = self .vae .z_dim
9089 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
9190 latent_height = height // self .vae_scale_factor_spatial
9291 latent_width = width // self .vae_scale_factor_spatial
You can’t perform that action at this time.
0 commit comments