Skip to content

Commit 1be0361

Browse files
committed
Fix double noise computation
1 parent 8a752e7 commit 1be0361

1 file changed

Lines changed: 51 additions & 67 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 51 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(
213213
self.devices_array = devices_array
214214
self.mesh = mesh
215215
self.config = config
216-
self.run_wan2_2 = config.model_name == "wan2.2"
216+
self.model_name = config.model_name
217217

218218
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
219219
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
@@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
379379
mesh = Mesh(devices_array, config.mesh_axes)
380380
rng = jax.random.key(config.seed)
381381
rngs = nnx.Rngs(rng)
382-
run_wan2_2 = config.model_name == "wan2.2"
382+
model_name = config.model_name
383383
low_noise_transformer = None
384384
high_noise_transformer = None
385385
tokenizer = None
@@ -390,7 +390,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
390390
if load_transformer:
391391
with mesh:
392392
low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer")
393-
if run_wan2_2:
393+
if model_name == "wan2.2":
394394
high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2")
395395

396396
text_encoder = cls.load_text_encoder(config=config)
@@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
421421
mesh = Mesh(devices_array, config.mesh_axes)
422422
rng = jax.random.key(config.seed)
423423
rngs = nnx.Rngs(rng)
424-
run_wan2_2 = config.model_name == "wan2.2"
424+
model_name = config.model_name
425425
low_noise_transformer = None
426426
high_noise_transformer = None
427427
tokenizer = None
@@ -432,7 +432,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
432432
if load_transformer:
433433
with mesh:
434434
low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer")
435-
if run_wan2_2:
435+
if model_name == "wan2.2":
436436
high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2")
437437
text_encoder = cls.load_text_encoder(config=config)
438438
tokenizer = cls.load_tokenizer(config=config)
@@ -457,7 +457,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
457457
)
458458

459459
pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh)
460-
if run_wan2_2:
460+
if model_name == "wan2.2":
461461
pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh)
462462
return pipeline
463463

@@ -617,12 +617,12 @@ def __call__(
617617

618618
low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...)
619619
high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None
620-
if self.run_wan2_2:
620+
if self.model_name == "wan2.2":
621621
high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...)
622622

623623
p_run_inference = partial(
624624
run_inference,
625-
run_wan2_2=self.run_wan2_2,
625+
model_name=self.model_name,
626626
guidance_scale=guidance_scale,
627627
guidance_scale_low=guidance_scale_low,
628628
guidance_scale_high=guidance_scale_high,
@@ -659,51 +659,27 @@ def __call__(
659659
return video
660660

661661

662-
@partial(jax.jit, static_argnames=("run_wan2_2", "guidance_scale", "guidance_scale_low", "guidance_scale_high", "boundary", "do_classifier_free_guidance"))
662+
@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale"))
663663
def transformer_forward_pass(
664-
low_noise_graphdef,
665-
low_noise_state,
666-
low_noise_rest,
667-
high_noise_graphdef,
668-
high_noise_state,
669-
high_noise_rest,
670-
latents, timestep,
664+
graphdef,
665+
sharded_state,
666+
rest_of_state,
667+
latents,
668+
timestep,
671669
prompt_embeds,
672-
run_wan2_2: bool,
673-
guidance_scale: float,
674-
guidance_scale_low: float,
675-
guidance_scale_high: float,
676-
boundary: int,
677-
do_classifier_free_guidance: bool,
678-
t: jnp.array,
670+
do_classifier_free_guidance,
671+
guidance_scale,
679672
):
680-
low_noise_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest)
681-
noise_pred_low = low_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds)
682-
noise_pred = noise_pred_low
683-
current_guide_scale = guidance_scale
684-
if run_wan2_2:
685-
high_noise_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest)
686-
noise_pred_high = high_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds)
687-
use_high_noise = jnp.greater_equal(t, boundary)
688-
noise_pred = jax.lax.cond(
689-
use_high_noise,
690-
lambda: noise_pred_high,
691-
lambda: noise_pred_low,
692-
)
693-
current_guide_scale = jax.lax.cond(
694-
use_high_noise,
695-
lambda: guidance_scale_high,
696-
lambda: guidance_scale_low,
697-
)
698-
699-
if do_classifier_free_guidance:
700-
bsz = latents.shape[0] // 2
701-
noise_uncond = noise_pred[bsz:]
702-
noise_pred = noise_pred[:bsz]
703-
noise_pred = noise_uncond + current_guide_scale * (noise_pred - noise_uncond)
704-
latents = latents[:bsz]
673+
wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
674+
noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds)
675+
if do_classifier_free_guidance:
676+
bsz = latents.shape[0] // 2
677+
noise_uncond = noise_pred[bsz:]
678+
noise_pred = noise_pred[:bsz]
679+
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
680+
latents = latents[:bsz]
705681

706-
return noise_pred, latents
682+
return noise_pred, latents
707683

708684
def run_inference(
709685
low_noise_graphdef,
@@ -715,7 +691,7 @@ def run_inference(
715691
latents: jnp.array,
716692
prompt_embeds: jnp.array,
717693
negative_prompt_embeds: jnp.array,
718-
run_wan2_2: bool,
694+
model_name: str,
719695
guidance_scale: float,
720696
guidance_scale_low: float,
721697
guidance_scale_high: float,
@@ -725,32 +701,40 @@ def run_inference(
725701
scheduler_state,
726702
):
727703
do_classifier_free_guidance = guidance_scale > 1.0
728-
if run_wan2_2:
704+
if model_name == "wan2.2":
729705
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
730706
if do_classifier_free_guidance:
731707
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
708+
709+
def low_noise_branch(operands):
710+
latents, timestep, prompt_embeds = operands
711+
return transformer_forward_pass(
712+
low_noise_graphdef, low_noise_state, low_noise_rest,
713+
latents, timestep, prompt_embeds,
714+
do_classifier_free_guidance, guidance_scale_low
715+
)
716+
717+
def high_noise_branch(operands):
718+
latents, timestep, prompt_embeds = operands
719+
return transformer_forward_pass(
720+
high_noise_graphdef, high_noise_state, high_noise_rest,
721+
latents, timestep, prompt_embeds,
722+
do_classifier_free_guidance, guidance_scale_high
723+
)
724+
732725
for step in range(num_inference_steps):
733726
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
734727
if do_classifier_free_guidance:
735728
latents = jnp.concatenate([latents] * 2)
736729
timestep = jnp.broadcast_to(t, latents.shape[0])
737730

738-
noise_pred, latents = transformer_forward_pass(
739-
low_noise_graphdef,
740-
low_noise_state,
741-
low_noise_rest,
742-
high_noise_graphdef,
743-
high_noise_state,
744-
high_noise_rest,
745-
latents, timestep,
746-
prompt_embeds,
747-
run_wan2_2,
748-
guidance_scale,
749-
guidance_scale_low,
750-
guidance_scale_high,
751-
boundary,
752-
do_classifier_free_guidance,
753-
t
731+
use_high_noise = jnp.greater_equal(t, boundary)
732+
733+
noise_pred, latents = jax.lax.cond(
734+
use_high_noise,
735+
high_noise_branch,
736+
low_noise_branch,
737+
(latents, timestep, prompt_embeds)
754738
)
755739

756740
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

0 commit comments

Comments (0)