@@ -109,6 +109,16 @@ def __init__(
109109 self .vae_scale_factor_spatial = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
110110 self .video_processor = VideoProcessor (vae_scale_factor = self .vae_scale_factor_spatial )
111111
112+ self .jitted_decode = jax .jit (
113+ partial (
114+ self .vae .decode ,
115+ feat_cache = self .vae_cache ,
116+ return_dict = False
117+ )
118+ )
119+
120+ self .p_run_inference = None
121+
112122 @classmethod
113123 def load_text_encoder (cls , config : HyperParameters ):
114124 text_encoder = UMT5EncoderModel .from_pretrained (
@@ -184,20 +194,27 @@ def load_scheduler(cls, config):
184194 return scheduler , scheduler_state
185195
186196 @classmethod
187- def from_pretrained (cls , config : HyperParameters ):
197+ def from_pretrained (cls , config : HyperParameters , vae_only = False ):
188198 devices_array = max_utils .create_device_mesh (config )
189199 mesh = Mesh (devices_array , config .mesh_axes )
190200 rng = jax .random .key (config .seed )
191201 rngs = nnx .Rngs (rng )
202+ transformer = None
203+ tokenizer = None
204+ scheduler = None
205+ scheduler_state = None
206+ text_encoder = None
207+ if not vae_only :
208+ with mesh :
209+ transformer = cls .load_transformer (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
210+
211+ text_encoder = cls .load_text_encoder (config = config )
212+ tokenizer = cls .load_tokenizer (config = config )
213+
214+ scheduler , scheduler_state = cls .load_scheduler (config = config )
192215
193216 with mesh :
194217 wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
195- transformer = cls .load_transformer (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
196-
197- text_encoder = cls .load_text_encoder (config = config )
198- tokenizer = cls .load_tokenizer (config = config )
199-
200- scheduler , scheduler_state = cls .load_scheduler (config = config )
201218
202219 return WanPipeline (
203220 tokenizer = tokenizer ,
@@ -291,10 +308,10 @@ def prepare_latents(
291308 num_latent_frames = (num_frames - 1 ) // vae_scale_factor_temporal + 1
292309 shape = (
293310 batch_size ,
311+ num_channels_latents ,
294312 num_latent_frames ,
295313 int (height ) // vae_scale_factor_spatial ,
296314 int (width ) // vae_scale_factor_spatial ,
297- num_channels_latents
298315 )
299316 latents = jax .random .normal (rng , shape = shape , dtype = self .config .weights_dtype )
300317
@@ -370,6 +387,7 @@ def __call__(
370387 scheduler = self .scheduler ,
371388 scheduler_state = scheduler_state
372389 )
390+
373391 with self .mesh :
374392 latents = p_run_inference (
375393 graphdef = graphdef ,
@@ -379,21 +397,13 @@ def __call__(
379397 prompt_embeds = prompt_embeds ,
380398 negative_prompt_embeds = negative_prompt_embeds
381399 )
382- latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , 1 , 1 , 1 , self . vae . z_dim )
383- latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , 1 , 1 , 1 , self . vae . z_dim )
400+ latents_mean = jnp .array (self .vae .latents_mean ).reshape (1 , self . vae . z_dim , 1 , 1 , 1 )
401+ latents_std = 1.0 / jnp .array (self .vae .latents_std ).reshape (1 , self . vae . z_dim , 1 , 1 , 1 )
384402 latents = latents / latents_std + latents_mean
385-
386403 latents = latents .astype (self .config .weights_dtype )
387404
388- jitted_decode = jax .jit (
389- partial (
390- self .vae .decode ,
391- feat_cache = self .vae_cache ,
392- return_dict = False
393- )
394- )
395405 with self .mesh :
396- video = jitted_decode (latents )[0 ]
406+ video = self . jitted_decode (latents )[0 ]
397407 video = jnp .transpose (video , (0 , 4 , 1 , 2 , 3 ))
398408 video = torch .from_numpy (np .array (video .astype (dtype = jnp .float32 ))).to (dtype = torch .bfloat16 )
399409 video = self .video_processor .postprocess_video (video , output_type = "np" )
@@ -431,6 +441,5 @@ def run_inference(
431441 if do_classifier_free_guidance :
432442 noise_uncond = transformer_forward_pass (graphdef , sharded_state , rest_of_state , latents , timestep , negative_prompt_embeds )
433443 noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond )
434-
435444 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
436445 return latents
0 commit comments