@@ -199,13 +199,13 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
     wan_vae = nnx.merge(graphdef, params)
     p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
     # Shard
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       wan_vae = p_create_sharded_logical_model(model=wan_vae)
     return wan_vae, vae_cache

   @classmethod
   def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
     return wan_transformer

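These two hunks shard the VAE and transformer under flax's logical-axis-rules context as well as the mesh, so the names in config.logical_axis_rules resolve to mesh axes while the sharded modules are built. A minimal sketch of what nn_partitioning.axis_rules does (the single-axis mesh and the rules below are illustrative stand-ins, not this repository's config):

import jax
import numpy as np
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

# Illustrative stand-ins for devices_array and config.logical_axis_rules.
devices_array = np.array(jax.devices())
mesh = Mesh(devices_array, axis_names=("data",))
logical_axis_rules = (("batch", "data"), ("embed", None))

with mesh, nn_partitioning.axis_rules(logical_axis_rules):
  # Inside the context, logical axis names resolve to mesh axes for any code
  # that reads the ambient rules (such as sharded model creation).
  spec = nn_partitioning.logical_to_mesh_axes(("batch", "embed"))
  print(spec)  # PartitionSpec('data', None)
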
@@ -468,11 +468,17 @@ def run_inference(
     slg_end: float = 1.0,
 ):
   do_classifier_free_guidance = guidance_scale > 1.0
+  if do_classifier_free_guidance:
+    prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
   for step in range(num_inference_steps):
     slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_)
     if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps):
       slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True)
     t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+    # get original batch size before concat in case of cfg.
+    bsz = latents.shape[0]
+    if do_classifier_free_guidance:
+      latents = jnp.concatenate([latents] * 2)
     timestep = jnp.broadcast_to(t, latents.shape[0])

     noise_pred = transformer_forward_pass(
@@ -482,21 +488,14 @@ def run_inference(
         latents,
         timestep,
         prompt_embeds,
-        is_uncond=jnp.array(False, dtype=jnp.bool_),
+        is_uncond=jnp.array(True, dtype=jnp.bool_),
         slg_mask=slg_mask,
     )

     if do_classifier_free_guidance:
-      noise_uncond = transformer_forward_pass(
-          graphdef,
-          sharded_state,
-          rest_of_state,
-          latents,
-          timestep,
-          negative_prompt_embeds,
-          is_uncond=jnp.array(True, dtype=jnp.bool_),
-          slg_mask=slg_mask,
-      )
+      noise_uncond = noise_pred[bsz:]
+      noise_pred = noise_pred[:bsz]
       noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+      latents = latents[:bsz]
     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
   return latents
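The run_inference change batches classifier-free guidance: the positive and negative prompt embeddings are concatenated once before the denoising loop, the latents are duplicated each step so a single transformer_forward_pass covers both branches, and the stacked prediction is split back before applying the guidance formula. A self-contained sketch of that pattern, with fake_transformer and all shapes as illustrative assumptions rather than this repository's model:

import jax.numpy as jnp

def fake_transformer(latents, embeds):
  # Stand-in for transformer_forward_pass: a per-sample "noise" that depends
  # on the latents and a summary of the text embeddings.
  return 0.1 * latents + embeds.mean(axis=(1, 2)).reshape(-1, 1, 1, 1)

bsz, guidance_scale = 1, 5.0
latents = jnp.ones((bsz, 4, 8, 8))
prompt_embeds = jnp.ones((bsz, 16, 32))
negative_prompt_embeds = jnp.zeros((bsz, 16, 32))

# Concatenate conditional and unconditional embeddings once, outside the loop.
embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

# Each step: double the latents so one forward pass serves both branches.
latents_in = jnp.concatenate([latents] * 2)
noise = fake_transformer(latents_in, embeds)

# Split the stacked output: first half conditional, second half unconditional.
noise_cond, noise_uncond = noise[:bsz], noise[bsz:]
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
assert noise_pred.shape == latents.shape

Compared with the previous two sequential transformer calls, this performs the conditional and unconditional branches in a single doubled-batch call per denoising step.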