
Commit 2bae741

Author: James Huang (committed)
Commit message: implement sencache
Signed-off-by: James Huang <shyhuang@google.com>
1 parent 0aea69b · commit 2bae741

2 files changed: 107 additions & 11 deletions


src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 1 deletion
@@ -302,8 +302,10 @@ guidance_scale_high: 4.0
 # timestep to switch between low noise and high noise transformer
 boundary_ratio: 0.875

-# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
+# Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+# SenCache: sensitivity-aware adaptive caching (Haghighi & Alahi, 2026)
+use_sen_cache: False

 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
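
For orientation, the caching paths added in the pipeline below branch on t >= boundary, where the boundary timestep comes from boundary_ratio. A rough back-of-the-envelope sketch, assuming (as in the reference WAN 2.2 pipeline) boundary = boundary_ratio * num_train_timesteps with 1000 training timesteps, and approximating the inference timesteps as linearly spaced; the real values come from the UniPC scheduler:

# Illustrative only: how boundary_ratio maps to the high/low-noise split.
num_train_timesteps = 1000  # assumption: WAN's training timestep count
boundary_ratio = 0.875
boundary = boundary_ratio * num_train_timesteps  # 875.0

num_inference_steps = 50
# crude linear stand-in for the scheduler's timestep sequence (1000 -> ~0)
timesteps = [round(num_train_timesteps * (1 - s / num_inference_steps)) for s in range(num_inference_steps)]
high_noise_steps = sum(t >= boundary for t in timesteps)
print(f"{high_noise_steps}/{num_inference_steps} steps use the high-noise transformer")  # -> 7/50

With the default ratio, only the first handful of steps run the high-noise expert, which is why the SenCache loop below forces a full computation right at that expert switch.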

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 104 additions & 10 deletions
@@ -111,14 +111,25 @@ def __call__(
       negative_prompt_embeds: jax.Array = None,
       vae_only: bool = False,
       use_cfg_cache: bool = False,
+      use_sen_cache: bool = False,
   ):
+    if use_cfg_cache and use_sen_cache:
+      raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
+
     if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
       raise ValueError(
           f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
           f"(got {guidance_scale_low}, {guidance_scale_high}). "
           "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
       )

+    if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
+      raise ValueError(
+          f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
+          f"(got {guidance_scale_low}, {guidance_scale_high}). "
+          "SenCache requires classifier-free guidance to be enabled for both transformer phases."
+      )
+
     latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
         prompt,
         negative_prompt,
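
Read together, the two checks above define which flag combinations __call__ accepts. A standalone restatement, useful for eyeballing the accepted combinations (the helper name and example values are illustrative, not part of this commit):

# Hypothetical helper mirroring the validation added to __call__ above.
def validate_cache_flags(use_cfg_cache, use_sen_cache, guidance_scale_low, guidance_scale_high):
  if use_cfg_cache and use_sen_cache:
    raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
  if (use_cfg_cache or use_sen_cache) and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
    raise ValueError("Caching accelerates CFG, so both guidance scales must be > 1.0.")

validate_cache_flags(False, True, guidance_scale_low=3.0, guidance_scale_high=4.0)  # SenCache + CFG: accepted
# validate_cache_flags(True, True, 3.0, 4.0)   would raise: both caches enabled
# validate_cache_flags(False, True, 1.0, 4.0)  would raise: CFG disabled for the low-noise phase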
@@ -148,6 +159,7 @@ def __call__(
         scheduler=self.scheduler,
         scheduler_state=scheduler_state,
         use_cfg_cache=use_cfg_cache,
+        use_sen_cache=use_sen_cache,
         height=height,
     )

@@ -184,22 +196,104 @@ def run_inference_2_2(
     scheduler: FlaxUniPCMultistepScheduler,
     scheduler_state,
     use_cfg_cache: bool = False,
+    use_sen_cache: bool = False,
     height: int = 480,
 ):
-  """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
-
-  Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
-  - High-noise phase (t >= boundary): always full CFG — short phase, critical
-    for establishing video structure.
-  - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
-    steps, FFT frequency-domain compensation on cache steps (batch×1).
-  - Boundary transition: mandatory full CFG step to populate cache for the
-    low-noise transformer.
-  - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
+  """Denoising loop for WAN 2.2 T2V with optional caching acceleration.
+
+  Supports two caching strategies:
+
+  1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
+     Caches the unconditional branch and uses FFT frequency-domain compensation.
+
+  2. SenCache (use_sen_cache=True) — Sensitivity-aware caching:
+     Measures output sensitivity after each full forward pass. When sensitivity
+     is low (model output is stable), skips the entire transformer and reuses
+     the cached noise prediction. Naturally handles MoE expert boundaries by
+     detecting high sensitivity at transition points.
   """
   do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
   bsz = latents.shape[0]

+  # ── SenCache path ──
+  if use_sen_cache and do_classifier_free_guidance:
+    timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
+    step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
+
+    # Resolution-dependent SenCache config
+    if height >= 720:
+      sen_threshold = 0.06  # tighter for higher resolution
+      warmup_ratio = 0.10
+      max_consecutive_cache = 2
+    else:
+      sen_threshold = 0.08
+      warmup_ratio = 0.08
+      max_consecutive_cache = 3
+
+    warmup_steps = max(2, int(num_inference_steps * warmup_ratio))
+
+    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
+
+    # SenCache state
+    prev_noise_pred = None      # last full-computation noise prediction
+    sensitivity = float('inf')  # measured relative output change
+    consecutive_cached = 0      # consecutive steps using cache
+    cache_count = 0
+
+    for step in range(num_inference_steps):
+      t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+
+      # Select transformer and guidance scale
+      if step_uses_high[step]:
+        graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
+        guidance_scale = guidance_scale_high
+      else:
+        graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
+        guidance_scale = guidance_scale_low
+
+      # Caching decision
+      is_warmup = step < warmup_steps
+      is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
+      should_cache = (
+          not is_warmup
+          and not is_boundary
+          and prev_noise_pred is not None
+          and sensitivity < sen_threshold
+          and consecutive_cached < max_consecutive_cache
+      )
+
+      if should_cache:
+        # ── Cache step: reuse previous noise prediction ──
+        noise_pred = prev_noise_pred
+        consecutive_cached += 1
+        cache_count += 1
+      else:
+        # ── Full CFG step ──
+        latents_doubled = jnp.concatenate([latents] * 2)
+        timestep = jnp.broadcast_to(t, bsz * 2)
+        noise_pred, _, _ = transformer_forward_pass_full_cfg(
+            graphdef, state, rest,
+            latents_doubled, timestep, prompt_embeds_combined,
+            guidance_scale=guidance_scale,
+        )
+
+        # Measure sensitivity: relative output change since last full step
+        if prev_noise_pred is not None:
+          output_diff = jnp.mean(jnp.abs(noise_pred - prev_noise_pred))
+          output_magnitude = jnp.mean(jnp.abs(noise_pred)) + 1e-8
+          sensitivity = float(output_diff / output_magnitude)
+        else:
+          sensitivity = float('inf')
+
+        prev_noise_pred = noise_pred
+        consecutive_cached = 0
+
+      latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+
+    print(f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
+          f"({100*cache_count/num_inference_steps:.1f}% cache ratio)")
+    return latents
+
   # ── CFG cache path ──
   if use_cfg_cache and do_classifier_free_guidance:
     # Get timesteps as numpy for Python-level scheduling decisions
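
As a reading aid for the sensitivity test in the SenCache branch above, here is a small self-contained sketch of the same metric on dummy arrays (NumPy instead of jax.numpy, invented values; the thresholds are the ones from this diff). It is illustrative only, not part of the commit:

import numpy as np

def sensitivity(noise_pred, prev_noise_pred, eps=1e-8):
  # Relative L1 change since the last fully computed step, matching the
  # output_diff / output_magnitude formula in the SenCache branch above.
  return float(np.mean(np.abs(noise_pred - prev_noise_pred)) / (np.mean(np.abs(noise_pred)) + eps))

rng = np.random.default_rng(0)
prev = rng.normal(size=(1, 16, 16, 4)).astype(np.float32)
small_drift = prev + 0.01 * rng.normal(size=prev.shape).astype(np.float32)
large_drift = prev + 0.30 * rng.normal(size=prev.shape).astype(np.float32)

sen_threshold = 0.08  # 0.06 at >= 720p in this commit
print(sensitivity(small_drift, prev) < sen_threshold)  # True:  next step may reuse the cached prediction
print(sensitivity(large_drift, prev) < sen_threshold)  # False: next step recomputes the transformer

Since a cached step skips the transformer entirely, the reported cache ratio bounds the speedup of the denoising loop: caching 15 of 50 steps (30%) leaves roughly 70% of the transformer work, i.e. at most about a 1.4x faster loop, before accounting for text encoding and VAE decode.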
