@@ -111,14 +111,25 @@ def __call__(
         negative_prompt_embeds: jax.Array = None,
         vae_only: bool = False,
         use_cfg_cache: bool = False,
+        use_sen_cache: bool = False,
     ):
+        if use_cfg_cache and use_sen_cache:
+            raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
+
         if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
             raise ValueError(
                 f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
                 f"(got {guidance_scale_low}, {guidance_scale_high}). "
                 "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
             )
 
+        if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
+            raise ValueError(
+                f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
+                f"(got {guidance_scale_low}, {guidance_scale_high}). "
+                "SenCache requires classifier-free guidance to be enabled for both transformer phases."
+            )
+
         latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
             prompt,
             negative_prompt,
@@ -148,6 +159,7 @@ def __call__(
             scheduler=self.scheduler,
             scheduler_state=scheduler_state,
             use_cfg_cache=use_cfg_cache,
+            use_sen_cache=use_sen_cache,
             height=height,
         )
 
@@ -184,22 +196,104 @@ def run_inference_2_2(
     scheduler: FlaxUniPCMultistepScheduler,
     scheduler_state,
     use_cfg_cache: bool = False,
+    use_sen_cache: bool = False,
     height: int = 480,
 ):
189- """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190-
191- Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192- - High-noise phase (t >= boundary): always full CFG — short phase, critical
193- for establishing video structure.
194- - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195- steps, FFT frequency-domain compensation on cache steps (batch×1).
196- - Boundary transition: mandatory full CFG step to populate cache for the
197- low-noise transformer.
198- - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
202+ """Denoising loop for WAN 2.2 T2V with optional caching acceleration.
203+
204+ Supports two caching strategies:
205+
206+ 1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207+ Caches the unconditional branch and uses FFT frequency-domain compensation.
208+
209+ 2. SenCache (use_sen_cache=True) — Sensitivity-aware caching:
210+ Measures output sensitivity after each full forward pass. When sensitivity
211+ is low (model output is stable), skips the entire transformer and reuses
212+ the cached noise prediction. Naturally handles MoE expert boundaries by
213+ detecting high sensitivity at transition points.
199214 """
     do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
     bsz = latents.shape[0]
 
+    # ── SenCache path ──
+    if use_sen_cache and do_classifier_free_guidance:
+        timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
+        step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
+
+        # Resolution-dependent SenCache config
+        if height >= 720:
+            sen_threshold = 0.06  # tighter for higher resolution
+            warmup_ratio = 0.10
+            max_consecutive_cache = 2
+        else:
+            sen_threshold = 0.08
+            warmup_ratio = 0.08
+            max_consecutive_cache = 3
+
+        warmup_steps = max(2, int(num_inference_steps * warmup_ratio))
+
+        prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
+
+        # SenCache state
+        prev_noise_pred = None      # last full-computation noise prediction
+        sensitivity = float('inf')  # measured relative output change
+        consecutive_cached = 0      # consecutive steps using cache
+        cache_count = 0
+
+        for step in range(num_inference_steps):
+            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+
+            # Select transformer and guidance scale
+            if step_uses_high[step]:
+                graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
+                guidance_scale = guidance_scale_high
+            else:
+                graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
+                guidance_scale = guidance_scale_low
+
+            # Caching decision
+            is_warmup = step < warmup_steps
+            is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
+            should_cache = (
+                not is_warmup
+                and not is_boundary
+                and prev_noise_pred is not None
+                and sensitivity < sen_threshold
+                and consecutive_cached < max_consecutive_cache
+            )
+
+            if should_cache:
+                # ── Cache step: reuse previous noise prediction ──
+                noise_pred = prev_noise_pred
+                consecutive_cached += 1
+                cache_count += 1
+            else:
+                # ── Full CFG step ──
+                latents_doubled = jnp.concatenate([latents] * 2)
+                timestep = jnp.broadcast_to(t, bsz * 2)
+                noise_pred, _, _ = transformer_forward_pass_full_cfg(
+                    graphdef, state, rest,
+                    latents_doubled, timestep, prompt_embeds_combined,
+                    guidance_scale=guidance_scale,
+                )
+
+                # Measure sensitivity: relative output change since last full step
+                if prev_noise_pred is not None:
+                    output_diff = jnp.mean(jnp.abs(noise_pred - prev_noise_pred))
+                    output_magnitude = jnp.mean(jnp.abs(noise_pred)) + 1e-8
+                    sensitivity = float(output_diff / output_magnitude)
+                else:
+                    sensitivity = float('inf')
+
+                prev_noise_pred = noise_pred
+                consecutive_cached = 0
+
+            latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+
+        print(f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
+              f"({100 * cache_count / num_inference_steps:.1f}% cache ratio)")
+        return latents
+
     # ── CFG cache path ──
     if use_cfg_cache and do_classifier_free_guidance:
         # Get timesteps as numpy for Python-level scheduling decisions
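
For readers who want the SenCache decision rule in isolation, here is a self-contained sketch that restates the gating logic and sensitivity metric from the diff above. It is not part of the patch; `relative_change` and `should_use_cache` are illustrative names for logic the diff inlines in the loop.

import jax.numpy as jnp

def relative_change(curr, prev, eps=1e-8):
    # Relative L1 change between consecutive full-step noise predictions;
    # the same sensitivity metric the diff computes inline.
    return float(jnp.mean(jnp.abs(curr - prev)) / (jnp.mean(jnp.abs(curr)) + eps))

def should_use_cache(step, warmup_steps, is_boundary, prev_noise_pred,
                     sensitivity, sen_threshold, consecutive_cached,
                     max_consecutive_cache):
    # Reuse the cached prediction only when all of the following hold:
    #  - warmup is over (early steps set global structure),
    #  - we are not crossing the high-/low-noise expert boundary,
    #  - a cached prediction exists,
    #  - the last measured sensitivity is below the threshold,
    #  - we have not already cached too many steps in a row.
    return (
        step >= warmup_steps
        and not is_boundary
        and prev_noise_pred is not None
        and sensitivity < sen_threshold
        and consecutive_cached < max_consecutive_cache
    )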
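
And a minimal call-site sketch, assuming a constructed pipeline object exposing the `__call__` signature patched above; the prompt text and guidance values are placeholders, and what the call returns depends on the rest of the pipeline:

# Hypothetical invocation; pipeline construction and weight loading elided.
out = pipeline(
    prompt="a red panda napping on a mossy branch",
    negative_prompt="blurry, low quality",
    guidance_scale_low=3.0,   # must be > 1.0 when use_sen_cache=True
    guidance_scale_high=4.0,  # must be > 1.0 when use_sen_cache=True
    use_sen_cache=True,       # mutually exclusive with use_cfg_cache
)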