Skip to content

Commit 773aa33

Browse files
committed
removing hardcoded magcache params values
1 parent e2beb8c commit 773aa33

6 files changed

Lines changed: 42 additions & 14 deletions

File tree

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ use_cfg_cache: False
307307
use_magcache: True
308308
magcache_thresh: 0.06
309309
magcache_K: 2
310-
retention_ratio: 0.2
310+
retention_ratio: 0.4
311311
mag_ratios_base: [1.0, 1.0, 1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
312312

313313
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ guidance_scale_high: 4.0
296296
# The timestep threshold. If `t` is at or above this value,
297297
# the `high_noise_model` is considered as the required model.
298298
# timestep to switch between low noise and high noise transformer
299-
boundary_ratio: 0.875
299+
boundary_ratio: 0.9
300300

301301
# Diffusion CFG cache (FasterCache-style)
302302
use_cfg_cache: False

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,17 @@ def __call__(
9292
vae_only: bool = False,
9393
use_cfg_cache: bool = False,
9494
use_magcache: bool = False,
95-
magcache_thresh: float = 0.12,
96-
magcache_K: int = 2,
97-
retention_ratio: float = 0.2,
95+
magcache_thresh: Optional[float] = None,
96+
magcache_K: Optional[int] = None,
97+
retention_ratio: Optional[float] = None,
9898
):
99+
if magcache_thresh is None:
100+
magcache_thresh = getattr(self.config, "magcache_thresh", 0.12)
101+
if magcache_K is None:
102+
magcache_K = getattr(self.config, "magcache_K", 2)
103+
if retention_ratio is None:
104+
retention_ratio = getattr(self.config, "retention_ratio", 0.2)
105+
99106
if use_cfg_cache and guidance_scale <= 1.0:
100107
raise ValueError(
101108
f"use_cfg_cache=True requires guidance_scale > 1.0 (got {guidance_scale}). "

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,17 @@ def __call__(
113113
use_cfg_cache: bool = False,
114114
use_sen_cache: bool = False,
115115
use_magcache: bool = False,
116-
magcache_thresh: float = 0.04,
117-
magcache_K: int = 2,
118-
retention_ratio: float = 0.2,
116+
magcache_thresh: Optional[float] = None,
117+
magcache_K: Optional[int] = None,
118+
retention_ratio: Optional[float] = None,
119119
):
120+
if magcache_thresh is None:
121+
magcache_thresh = getattr(self.config, "magcache_thresh", 0.04)
122+
if magcache_K is None:
123+
magcache_K = getattr(self.config, "magcache_K", 2)
124+
if retention_ratio is None:
125+
retention_ratio = getattr(self.config, "retention_ratio", 0.2)
126+
120127
if use_cfg_cache and use_sen_cache:
121128
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
122129

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,17 @@ def __call__(
150150
output_type: Optional[str] = "np",
151151
rng: Optional[jax.Array] = None,
152152
use_magcache: bool = False,
153-
magcache_thresh: float = 0.04,
154-
magcache_K: int = 2,
155-
retention_ratio: float = 0.2,
153+
magcache_thresh: Optional[float] = None,
154+
magcache_K: Optional[int] = None,
155+
retention_ratio: Optional[float] = None,
156156
):
157+
if magcache_thresh is None:
158+
magcache_thresh = getattr(self.config, "magcache_thresh", 0.04)
159+
if magcache_K is None:
160+
magcache_K = getattr(self.config, "magcache_K", 2)
161+
if retention_ratio is None:
162+
retention_ratio = getattr(self.config, "retention_ratio", 0.2)
163+
157164
height = height or self.config.height
158165
width = width or self.config.width
159166
num_frames = num_frames or self.config.num_frames

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,17 @@ def __call__(
168168
rng: Optional[jax.Array] = None,
169169
use_cfg_cache: bool = False,
170170
use_magcache: bool = False,
171-
magcache_thresh: float = 0.04,
172-
magcache_K: int = 2,
173-
retention_ratio: float = 0.2,
171+
magcache_thresh: Optional[float] = None,
172+
magcache_K: Optional[int] = None,
173+
retention_ratio: Optional[float] = None,
174174
):
175+
if magcache_thresh is None:
176+
magcache_thresh = getattr(self.config, "magcache_thresh", 0.04)
177+
if magcache_K is None:
178+
magcache_K = getattr(self.config, "magcache_K", 2)
179+
if retention_ratio is None:
180+
retention_ratio = getattr(self.config, "retention_ratio", 0.2)
181+
175182
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
176183
raise ValueError(
177184
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "

0 commit comments

Comments
 (0)