File tree Expand file tree Collapse file tree
src/maxdiffusion/pipelines/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -86,6 +86,16 @@ def prepare_latents(
8686 latents : Optional [jax .Array ] = None ,
8787 last_image : Optional [jax .Array ] = None ,
8888 ) -> Tuple [jax .Array , jax .Array , Optional [jax .Array ]]:
89+
90+ if hasattr (image , "detach" ):
91+ image = image .detach ().cpu ().numpy ()
92+ image = jnp .array (image )
93+
94+ if last_image is not None :
95+ if hasattr (last_image , "detach" ):
96+ last_image = last_image .detach ().cpu ().numpy ()
97+ last_image = jnp .array (last_image )
98+
8999 num_channels_latents = self .vae .z_dim
90100 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
91101 latent_height = height // self .vae_scale_factor_spatial
Original file line number Diff line number Diff line change @@ -85,6 +85,16 @@ def prepare_latents(
8585 latents : Optional [jax .Array ] = None ,
8686 last_image : Optional [jax .Array ] = None ,
8787) -> Tuple [jax .Array , jax .Array , Optional [jax .Array ]]:
88+
89+ if hasattr (image , "detach" ):
90+ image = image .detach ().cpu ().numpy ()
91+ image = jnp .array (image )
92+
93+ if last_image is not None :
94+ if hasattr (last_image , "detach" ):
95+ last_image = last_image .detach ().cpu ().numpy ()
96+ last_image = jnp .array (last_image )
97+
8898 num_channels_latents = self .vae .z_dim
8999 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
90100 latent_height = height // self .vae_scale_factor_spatial
You can’t perform that action at this time.
0 commit comments