Skip to content

Commit 627ef88

Browse files
committed
correction needed for wan 2.2
1 parent ab31f35 commit 627ef88

5 files changed

Lines changed: 48 additions & 12 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,13 +912,15 @@ def nearest_interp(src, target_len):
912912
indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32)
913913
return src[indices]
914914

915-
def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
915+
def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_step=None, model_type="T2V"):
916916
"""Initialize MagCache variables and interpolate ratios.
917917
918918
Args:
919919
num_inference_steps: Number of inference steps.
920920
retention_ratio: Retention ratio of unchanged steps.
921921
mag_ratios_base: Base magnitude ratios array or list.
922+
split_step: Step at which model switches (e.g. high -> low noise for 2.2).
923+
model_type: Pipeline mode ("T2V" or "I2V").
922924
"""
923925
import numpy as np
924926

@@ -951,6 +953,8 @@ def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
951953
cached_residual,
952954
skip_warmup,
953955
mag_ratios,
956+
split_step,
957+
model_type,
954958
)
955959

956960
def magcache_step(
@@ -960,6 +964,10 @@ def magcache_step(
960964
magcache_thresh,
961965
magcache_K,
962966
skip_warmup,
967+
split_step=None,
968+
model_type="T2V",
969+
num_steps=None,
970+
retention_ratio=0.2,
963971
):
964972
"""Update MagCache accumulated state and decide if to skip.
965973
@@ -970,6 +978,10 @@ def magcache_step(
970978
magcache_thresh: Error threshold.
971979
magcache_K: Max skip steps.
972980
skip_warmup: Warmup steps threshold.
981+
split_step: Optional step index where the model switches (e.g., from high to low noise).
982+
model_type: Pipeline type ("T2V" or "I2V").
983+
num_steps: Total inference steps, used to calculate post-split warmups.
984+
retention_ratio: Used to calculate post-split warmups.
973985
"""
974986
import numpy as np
975987

@@ -985,8 +997,20 @@ def magcache_step(
985997
cur_mag_ratio_cond = mag_ratios[step * 2]
986998
cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]
987999

1000+
use_magcache = True
1001+
if split_step is not None:
1002+
if model_type == "I2V":
1003+
if step < int(split_step + (num_steps - split_step) * retention_ratio):
1004+
use_magcache = False
1005+
else:
1006+
if step < int(split_step * retention_ratio) or (step <= ((num_steps - split_step) * retention_ratio + split_step) and step >= split_step):
1007+
use_magcache = False
1008+
else:
1009+
if step < skip_warmup:
1010+
use_magcache = False
1011+
9881012
skip_blocks = False
989-
if step >= skip_warmup:
1013+
if use_magcache:
9901014
new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
9911015
new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
9921016

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ def run_inference_2_1(
246246
cached_residual,
247247
skip_warmup,
248248
mag_ratios,
249+
split_step,
250+
model_type,
249251
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
250252

251253
for step in range(num_inference_steps):
@@ -261,7 +263,7 @@ def run_inference_2_1(
261263
accumulated_steps_uncond,
262264
)
263265
skip_blocks, accumulated_state = magcache_step(
264-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
266+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
265267
)
266268
(
267269
accumulated_ratio_cond,

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,10 @@ def run_inference_2_2(
454454

455455
# ── MagCache path ──
456456
if use_magcache and do_classifier_free_guidance:
457+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
458+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
459+
split_step = sum(step_uses_high)
460+
457461
(
458462
accumulated_ratio_cond,
459463
accumulated_ratio_uncond,
@@ -464,10 +468,10 @@ def run_inference_2_2(
464468
cached_residual,
465469
skip_warmup,
466470
mag_ratios,
467-
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
471+
split_step,
472+
model_type,
473+
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_step=split_step, model_type=self.config.model_type)
468474

469-
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
470-
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
471475
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
472476

473477
for step in range(num_inference_steps):
@@ -482,7 +486,7 @@ def run_inference_2_2(
482486
accumulated_steps_uncond,
483487
)
484488
skip_blocks, accumulated_state = magcache_step(
485-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
489+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
486490
)
487491
(
488492
accumulated_ratio_cond,

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def run_inference_2_1_i2v(
295295
cached_residual,
296296
skip_warmup,
297297
mag_ratios,
298+
split_step,
299+
model_type,
298300
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
299301

300302
if do_classifier_free_guidance:
@@ -320,7 +322,7 @@ def run_inference_2_1_i2v(
320322
accumulated_steps_uncond,
321323
)
322324
skip_blocks, accumulated_state = magcache_step(
323-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
325+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
324326
)
325327
(
326328
accumulated_ratio_cond,

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,10 @@ def run_inference_2_2_i2v(
447447

448448
# ── MagCache path ──
449449
if use_magcache and do_classifier_free_guidance:
450+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
451+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
452+
split_step = sum(step_uses_high)
453+
450454
(
451455
accumulated_ratio_cond,
452456
accumulated_ratio_uncond,
@@ -457,10 +461,10 @@ def run_inference_2_2_i2v(
457461
cached_residual,
458462
skip_warmup,
459463
mag_ratios,
460-
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
464+
split_step,
465+
model_type,
466+
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_step=split_step, model_type=self.config.model_type)
461467

462-
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
463-
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
464468

465469
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
466470
if image_embeds is not None:
@@ -482,7 +486,7 @@ def run_inference_2_2_i2v(
482486
accumulated_steps_uncond,
483487
)
484488
skip_blocks, accumulated_state = magcache_step(
485-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
489+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
486490
)
487491
(
488492
accumulated_ratio_cond,

0 commit comments

Comments
 (0)