@@ -213,7 +213,7 @@ def __init__(
213213 self .devices_array = devices_array
214214 self .mesh = mesh
215215 self .config = config
216- self .run_wan2_2 = config .model_name == "wan2.2"
216+ self .model_name = config .model_name
217217
218218 self .vae_scale_factor_temporal = 2 ** sum (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 4
219219 self .vae_scale_factor_spatial = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
@@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
379379 mesh = Mesh (devices_array , config .mesh_axes )
380380 rng = jax .random .key (config .seed )
381381 rngs = nnx .Rngs (rng )
382- run_wan2_2 = config .model_name == "wan2.2"
382+ model_name = config .model_name
383383 low_noise_transformer = None
384384 high_noise_transformer = None
385385 tokenizer = None
@@ -390,7 +390,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
390390 if load_transformer :
391391 with mesh :
392392 low_noise_transformer = cls .load_transformer (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config , restored_checkpoint = restored_checkpoint , subfolder = "transformer" )
393- if run_wan2_2 :
393+ if model_name == "wan2.2" :
394394 high_noise_transformer = cls .load_transformer (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config , restored_checkpoint = restored_checkpoint , subfolder = "transformer_2" )
395395
396396 text_encoder = cls .load_text_encoder (config = config )
@@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
421421 mesh = Mesh (devices_array , config .mesh_axes )
422422 rng = jax .random .key (config .seed )
423423 rngs = nnx .Rngs (rng )
424- run_wan2_2 = config .model_name == "wan2.2"
424+ model_name = config .model_name
425425 low_noise_transformer = None
426426 high_noise_transformer = None
427427 tokenizer = None
@@ -432,7 +432,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
432432 if load_transformer :
433433 with mesh :
434434 low_noise_transformer = cls .load_transformer (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config , subfolder = "transformer" )
435- if run_wan2_2 :
435+ if model_name == "wan2.2" :
436436 high_noise_transformer = cls .load_transformer (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config , subfolder = "transformer_2" )
437437 text_encoder = cls .load_text_encoder (config = config )
438438 tokenizer = cls .load_tokenizer (config = config )
@@ -457,7 +457,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
457457 )
458458
459459 pipeline .low_noise_transformer = cls .quantize_transformer (config , pipeline .low_noise_transformer , pipeline , mesh )
460- if run_wan2_2 :
460+ if model_name == "wan2.2" :
461461 pipeline .high_noise_transformer = cls .quantize_transformer (config , pipeline .high_noise_transformer , pipeline , mesh )
462462 return pipeline
463463
@@ -617,12 +617,12 @@ def __call__(
617617
618618 low_noise_graphdef , low_noise_state , low_noise_rest = nnx .split (self .low_noise_transformer , nnx .Param , ...)
619619 high_noise_graphdef , high_noise_state , high_noise_rest = None , None , None
620- if self .run_wan2_2 :
620+ if self .model_name == "wan2.2" :
621621 high_noise_graphdef , high_noise_state , high_noise_rest = nnx .split (self .high_noise_transformer , nnx .Param , ...)
622622
623623 p_run_inference = partial (
624624 run_inference ,
625- run_wan2_2 = self .run_wan2_2 ,
625+ model_name = self .model_name ,
626626 guidance_scale = guidance_scale ,
627627 guidance_scale_low = guidance_scale_low ,
628628 guidance_scale_high = guidance_scale_high ,
@@ -659,51 +659,27 @@ def __call__(
659659 return video
660660
661661
662- @partial (jax .jit , static_argnames = ("run_wan2_2 " , "guidance_scale" , "guidance_scale_low" , "guidance_scale_high" , "boundary" , "do_classifier_free_guidance " ))
662+ @partial (jax .jit , static_argnames = ("do_classifier_free_guidance " , "guidance_scale" ))
663663def transformer_forward_pass (
664- low_noise_graphdef ,
665- low_noise_state ,
666- low_noise_rest ,
667- high_noise_graphdef ,
668- high_noise_state ,
669- high_noise_rest ,
670- latents , timestep ,
664+ graphdef ,
665+ sharded_state ,
666+ rest_of_state ,
667+ latents ,
668+ timestep ,
671669 prompt_embeds ,
672- run_wan2_2 : bool ,
673- guidance_scale : float ,
674- guidance_scale_low : float ,
675- guidance_scale_high : float ,
676- boundary : int ,
677- do_classifier_free_guidance : bool ,
678- t : jnp .array ,
670+ do_classifier_free_guidance ,
671+ guidance_scale ,
679672):
680- low_noise_transformer = nnx .merge (low_noise_graphdef , low_noise_state , low_noise_rest )
681- noise_pred_low = low_noise_transformer (hidden_states = latents , timestep = timestep , encoder_hidden_states = prompt_embeds )
682- noise_pred = noise_pred_low
683- current_guide_scale = guidance_scale
684- if run_wan2_2 :
685- high_noise_transformer = nnx .merge (high_noise_graphdef , high_noise_state , high_noise_rest )
686- noise_pred_high = high_noise_transformer (hidden_states = latents , timestep = timestep , encoder_hidden_states = prompt_embeds )
687- use_high_noise = jnp .greater_equal (t , boundary )
688- noise_pred = jax .lax .cond (
689- use_high_noise ,
690- lambda : noise_pred_high ,
691- lambda : noise_pred_low ,
692- )
693- current_guide_scale = jax .lax .cond (
694- use_high_noise ,
695- lambda : guidance_scale_high ,
696- lambda : guidance_scale_low ,
697- )
698-
699- if do_classifier_free_guidance :
700- bsz = latents .shape [0 ] // 2
701- noise_uncond = noise_pred [bsz :]
702- noise_pred = noise_pred [:bsz ]
703- noise_pred = noise_uncond + current_guide_scale * (noise_pred - noise_uncond )
704- latents = latents [:bsz ]
673+ wan_transformer = nnx .merge (graphdef , sharded_state , rest_of_state )
674+ noise_pred = wan_transformer (hidden_states = latents , timestep = timestep , encoder_hidden_states = prompt_embeds )
675+ if do_classifier_free_guidance :
676+ bsz = latents .shape [0 ] // 2
677+ noise_uncond = noise_pred [bsz :]
678+ noise_pred = noise_pred [:bsz ]
679+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond )
680+ latents = latents [:bsz ]
705681
706- return noise_pred , latents
682+ return noise_pred , latents
707683
708684def run_inference (
709685 low_noise_graphdef ,
@@ -715,7 +691,7 @@ def run_inference(
715691 latents : jnp .array ,
716692 prompt_embeds : jnp .array ,
717693 negative_prompt_embeds : jnp .array ,
718- run_wan2_2 : bool ,
694+ model_name : str ,
719695 guidance_scale : float ,
720696 guidance_scale_low : float ,
721697 guidance_scale_high : float ,
@@ -725,32 +701,40 @@ def run_inference(
725701 scheduler_state ,
726702):
727703 do_classifier_free_guidance = guidance_scale > 1.0
728- if run_wan2_2 :
704+ if model_name == "wan2.2" :
729705 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
730706 if do_classifier_free_guidance :
731707 prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
708+
709+ def low_noise_branch (operands ):
710+ latents , timestep , prompt_embeds = operands
711+ return transformer_forward_pass (
712+ low_noise_graphdef , low_noise_state , low_noise_rest ,
713+ latents , timestep , prompt_embeds ,
714+ do_classifier_free_guidance , guidance_scale_low
715+ )
716+
717+ def high_noise_branch (operands ):
718+ latents , timestep , prompt_embeds = operands
719+ return transformer_forward_pass (
720+ high_noise_graphdef , high_noise_state , high_noise_rest ,
721+ latents , timestep , prompt_embeds ,
722+ do_classifier_free_guidance , guidance_scale_high
723+ )
724+
732725 for step in range (num_inference_steps ):
733726 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
734727 if do_classifier_free_guidance :
735728 latents = jnp .concatenate ([latents ] * 2 )
736729 timestep = jnp .broadcast_to (t , latents .shape [0 ])
737730
738- noise_pred , latents = transformer_forward_pass (
739- low_noise_graphdef ,
740- low_noise_state ,
741- low_noise_rest ,
742- high_noise_graphdef ,
743- high_noise_state ,
744- high_noise_rest ,
745- latents , timestep ,
746- prompt_embeds ,
747- run_wan2_2 ,
748- guidance_scale ,
749- guidance_scale_low ,
750- guidance_scale_high ,
751- boundary ,
752- do_classifier_free_guidance ,
753- t
731+ use_high_noise = jnp .greater_equal (t , boundary )
732+
733+ noise_pred , latents = jax .lax .cond (
734+ use_high_noise ,
735+ high_noise_branch ,
736+ low_noise_branch ,
737+ (latents , timestep , prompt_embeds )
754738 )
755739
756740 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
0 commit comments