@@ -434,12 +434,13 @@ def __call__(
434434 prompt_embeds = prompt_embeds ,
435435 negative_prompt_embeds = negative_prompt_embeds ,
436436 )
437- latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
438- latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
439- latents = latents / latents_std + latents_mean
440- latents = latents .astype (self .config .weights_dtype )
441-
442- video = self .vae .decode (latents , self .vae_cache )[0 ]
437+ latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
438+ latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
439+ latents = latents / latents_std + latents_mean
440+ latents = latents .astype (self .config .weights_dtype )
441+
442+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
443+ video = self .vae .decode (latents , self .vae_cache )[0 ]
443444
444445 video = jnp .transpose (video , (0 , 4 , 1 , 2 , 3 ))
445446 video = torch .from_numpy (np .array (video .astype (dtype = jnp .float32 ))).to (dtype = torch .bfloat16 )
0 commit comments