@@ -234,8 +234,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
234234 subfolder = "vae" ,
235235 rngs = rngs ,
236236 mesh = mesh ,
237- dtype = config . activations_dtype ,
238- weights_dtype = config . weights_dtype ,
237+ dtype = jnp . float32 ,
238+ weights_dtype = jnp . float32 ,
239239 )
240240 return wan_vae
241241
@@ -494,7 +494,7 @@ def encode_prompt(
494494 num_videos_per_prompt = num_videos_per_prompt ,
495495 max_sequence_length = max_sequence_length ,
496496 )
497- prompt_embeds = jnp .array (prompt_embeds .detach ().numpy (), dtype = self . config . weights_dtype )
497+ prompt_embeds = jnp .array (prompt_embeds .detach ().numpy (), dtype = jnp . float32 )
498498
499499 if negative_prompt_embeds is None :
500500 negative_prompt = negative_prompt or ""
@@ -504,7 +504,7 @@ def encode_prompt(
504504 num_videos_per_prompt = num_videos_per_prompt ,
505505 max_sequence_length = max_sequence_length ,
506506 )
507- negative_prompt_embeds = jnp .array (negative_prompt_embeds .detach ().numpy (), dtype = self . config . weights_dtype )
507+ negative_prompt_embeds = jnp .array (negative_prompt_embeds .detach ().numpy (), dtype = jnp . float32 )
508508
509509 return prompt_embeds , negative_prompt_embeds
510510
@@ -527,7 +527,7 @@ def prepare_latents(
527527 int (height ) // vae_scale_factor_spatial ,
528528 int (width ) // vae_scale_factor_spatial ,
529529 )
530- latents = jax .random .normal (rng , shape = shape , dtype = self . config . weights_dtype )
530+ latents = jax .random .normal (rng , shape = shape , dtype = jnp . float32 )
531531
532532 return latents
533533
@@ -617,7 +617,7 @@ def __call__(
617617 latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
618618 latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , self .vae .z_dim , 1 , 1 , 1 )
619619 latents = latents / latents_std + latents_mean
620- latents = latents .astype (self . config . weights_dtype )
620+ latents = latents .astype (jnp . float32 )
621621
622622 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
623623 video = self .vae .decode (latents , self .vae_cache )[0 ]
0 commit comments