Skip to content

Commit 81e9dec

Browse files
committed
fix
1 parent 011729d commit 81e9dec

5 files changed

Lines changed: 11 additions & 39 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -912,15 +912,13 @@ def nearest_interp(src, target_len):
912912
indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32)
913913
return src[indices]
914914

915-
def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_step=None, model_type="T2V"):
915+
def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
916916
"""Initialize MagCache variables and interpolate ratios.
917917
918918
Args:
919919
num_inference_steps: Number of inference steps.
920920
retention_ratio: Retention ratio of unchanged steps.
921921
mag_ratios_base: Base magnitude ratios array or list.
922-
split_step: Step at which model switches (e.g. high -> low noise for 2.2).
923-
model_type: Pipeline mode ("T2V" or "I2V").
924922
"""
925923
import numpy as np
926924

@@ -953,8 +951,6 @@ def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_s
953951
cached_residual,
954952
skip_warmup,
955953
mag_ratios,
956-
split_step,
957-
model_type,
958954
)
959955

960956
def magcache_step(
@@ -964,10 +960,6 @@ def magcache_step(
964960
magcache_thresh,
965961
magcache_K,
966962
skip_warmup,
967-
split_step=None,
968-
model_type="T2V",
969-
num_steps=None,
970-
retention_ratio=0.2,
971963
):
972964
"""Update MagCache accumulated state and decide if to skip.
973965
@@ -978,10 +970,6 @@ def magcache_step(
978970
magcache_thresh: Error threshold.
979971
magcache_K: Max skip steps.
980972
skip_warmup: Warmup steps threshold.
981-
split_step: Optional step index where the model switches (e.g., from high to low noise).
982-
model_type: Pipeline type ("T2V" or "I2V").
983-
num_steps: Total inference steps, used to calculate post-split warmups.
984-
retention_ratio: Used to calculate post-split warmups.
985973
"""
986974
import numpy as np
987975

@@ -998,16 +986,8 @@ def magcache_step(
998986
cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]
999987

1000988
use_magcache = True
1001-
if split_step is not None:
1002-
if model_type == "I2V":
1003-
if step < int(split_step + (num_steps - split_step) * retention_ratio):
1004-
use_magcache = False
1005-
else:
1006-
if step < int(split_step * retention_ratio) or (step <= ((num_steps - split_step) * retention_ratio + split_step) and step >= split_step):
1007-
use_magcache = False
1008-
else:
1009-
if step < skip_warmup:
1010-
use_magcache = False
989+
if step < skip_warmup:
990+
use_magcache = False
1011991

1012992
skip_blocks = False
1013993
if use_magcache:

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,6 @@ def run_inference_2_1(
246246
cached_residual,
247247
skip_warmup,
248248
mag_ratios,
249-
split_step,
250-
model_type,
251249
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
252250

253251
for step in range(num_inference_steps):
@@ -263,7 +261,7 @@ def run_inference_2_1(
263261
accumulated_steps_uncond,
264262
)
265263
skip_blocks, accumulated_state = magcache_step(
266-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
264+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
267265
)
268266
(
269267
accumulated_ratio_cond,

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,6 @@ def run_inference_2_2(
456456
if use_magcache and do_classifier_free_guidance:
457457
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
458458
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
459-
split_step = sum(step_uses_high)
460459

461460
(
462461
accumulated_ratio_cond,
@@ -468,9 +467,7 @@ def run_inference_2_2(
468467
cached_residual,
469468
skip_warmup,
470469
mag_ratios,
471-
split_step,
472-
model_type,
473-
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_step=split_step, model_type="T2V")
470+
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
474471

475472
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
476473

@@ -486,7 +483,7 @@ def run_inference_2_2(
486483
accumulated_steps_uncond,
487484
)
488485
skip_blocks, accumulated_state = magcache_step(
489-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
486+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
490487
)
491488
(
492489
accumulated_ratio_cond,
@@ -503,6 +500,7 @@ def run_inference_2_2(
503500
else:
504501
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
505502
guidance_scale = guidance_scale_low
503+
skip_blocks = False # Reference MagCache only caches high_noise_model
506504

507505
timestep = jnp.broadcast_to(t, bsz * 2)
508506

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,6 @@ def run_inference_2_1_i2v(
295295
cached_residual,
296296
skip_warmup,
297297
mag_ratios,
298-
split_step,
299-
model_type,
300298
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
301299

302300
if do_classifier_free_guidance:
@@ -322,7 +320,7 @@ def run_inference_2_1_i2v(
322320
accumulated_steps_uncond,
323321
)
324322
skip_blocks, accumulated_state = magcache_step(
325-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
323+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
326324
)
327325
(
328326
accumulated_ratio_cond,

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,6 @@ def run_inference_2_2_i2v(
449449
if use_magcache and do_classifier_free_guidance:
450450
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
451451
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
452-
split_step = sum(step_uses_high)
453452

454453
(
455454
accumulated_ratio_cond,
@@ -461,9 +460,7 @@ def run_inference_2_2_i2v(
461460
cached_residual,
462461
skip_warmup,
463462
mag_ratios,
464-
split_step,
465-
model_type,
466-
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_step=split_step, model_type="I2V")
463+
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
467464

468465

469466
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
@@ -486,7 +483,7 @@ def run_inference_2_2_i2v(
486483
accumulated_steps_uncond,
487484
)
488485
skip_blocks, accumulated_state = magcache_step(
489-
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup, split_step=split_step, model_type=model_type, num_steps=num_inference_steps, retention_ratio=retention_ratio
486+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
490487
)
491488
(
492489
accumulated_ratio_cond,
@@ -503,6 +500,7 @@ def run_inference_2_2_i2v(
503500
else:
504501
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
505502
guidance_scale = guidance_scale_low
503+
skip_blocks = False # Reference MagCache only caches high_noise_model
506504

507505
timestep = jnp.broadcast_to(t, bsz * 2)
508506
latents_doubled = jnp.concatenate([latents, latents], axis=0)

0 commit comments

Comments
 (0)