Skip to content

Commit 2d8b31d

Browse files
committed
wan 2.2 fixes
1 parent 81e9dec commit 2d8b31d

3 files changed

Lines changed: 23 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,8 @@ def magcache_step(
959959
accumulated_state,
960960
magcache_thresh,
961961
magcache_K,
962-
skip_warmup,
962+
skip_warmup=0,
963+
use_magcache=None,
963964
):
964965
"""Update MagCache accumulated state and decide if to skip.
965966
@@ -970,6 +971,7 @@ def magcache_step(
970971
magcache_thresh: Error threshold.
971972
magcache_K: Max skip steps.
972973
skip_warmup: Warmup steps threshold.
974+
use_magcache: Optional manual override boolean to enable/disable cache for this step.
973975
"""
974976
import numpy as np
975977

@@ -985,9 +987,10 @@ def magcache_step(
985987
cur_mag_ratio_cond = mag_ratios[step * 2]
986988
cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]
987989

988-
use_magcache = True
989-
if step < skip_warmup:
990-
use_magcache = False
990+
if use_magcache is None:
991+
use_magcache = True
992+
if step < skip_warmup:
993+
use_magcache = False
991994

992995
skip_blocks = False
993996
if use_magcache:

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,17 @@ def run_inference_2_2(
471471

472472
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
473473

474+
high_noise_steps = sum(step_uses_high)
475+
474476
for step in range(num_inference_steps):
475477
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
476478

479+
use_magcache = True
480+
if step < int(high_noise_steps * retention_ratio):
481+
use_magcache = False
482+
elif step >= high_noise_steps and step <= int(high_noise_steps + (num_inference_steps - high_noise_steps) * retention_ratio):
483+
use_magcache = False
484+
477485
accumulated_state = (
478486
accumulated_ratio_cond,
479487
accumulated_ratio_uncond,
@@ -483,7 +491,7 @@ def run_inference_2_2(
483491
accumulated_steps_uncond,
484492
)
485493
skip_blocks, accumulated_state = magcache_step(
486-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
494+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, use_magcache=use_magcache
487495
)
488496
(
489497
accumulated_ratio_cond,
@@ -500,7 +508,6 @@ def run_inference_2_2(
500508
else:
501509
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
502510
guidance_scale = guidance_scale_low
503-
skip_blocks = False # Reference MagCache only caches high_noise_model
504511

505512
timestep = jnp.broadcast_to(t, bsz * 2)
506513

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,15 @@ def run_inference_2_2_i2v(
471471

472472
condition_combined = jnp.concatenate([condition] * 2)
473473

474+
high_noise_steps = sum(step_uses_high)
475+
474476
for step in range(num_inference_steps):
475477
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
476478

479+
use_magcache = True
480+
if step < int(high_noise_steps + (num_inference_steps - high_noise_steps) * retention_ratio):
481+
use_magcache = False
482+
477483
accumulated_state = (
478484
accumulated_ratio_cond,
479485
accumulated_ratio_uncond,
@@ -483,7 +489,7 @@ def run_inference_2_2_i2v(
483489
accumulated_steps_uncond,
484490
)
485491
skip_blocks, accumulated_state = magcache_step(
486-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
492+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, use_magcache=use_magcache
487493
)
488494
(
489495
accumulated_ratio_cond,
@@ -500,7 +506,6 @@ def run_inference_2_2_i2v(
500506
else:
501507
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
502508
guidance_scale = guidance_scale_low
503-
skip_blocks = False # Reference MagCache only caches high_noise_model
504509

505510
timestep = jnp.broadcast_to(t, bsz * 2)
506511
latents_doubled = jnp.concatenate([latents, latents], axis=0)

0 commit comments

Comments
 (0)