1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from .wan_pipeline import WanPipeline , transformer_forward_pass
15+ from .wan_pipeline import WanPipeline , transformer_forward_pass , transformer_forward_pass_full_cfg , transformer_forward_pass_cfg_cache
1616from ...models .wan .transformers .transformer_wan import WanModel
1717from typing import List , Union , Optional
1818from ...pyconfig import HyperParameters
2121from flax .linen import partitioning as nn_partitioning
2222import jax
2323import jax .numpy as jnp
24+ import numpy as np
2425from ...schedulers .scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2526
2627
@@ -32,7 +33,7 @@ def __init__(
3233 config : HyperParameters ,
3334 low_noise_transformer : Optional [WanModel ],
3435 high_noise_transformer : Optional [WanModel ],
35- ** kwargs
36+ ** kwargs ,
3637 ):
3738 super ().__init__ (config = config , ** kwargs )
3839 self .low_noise_transformer = low_noise_transformer
@@ -109,7 +110,15 @@ def __call__(
109110 prompt_embeds : jax .Array = None ,
110111 negative_prompt_embeds : jax .Array = None ,
111112 vae_only : bool = False ,
113+ use_cfg_cache : bool = False ,
112114 ):
115+ if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
116+ raise ValueError (
117+ f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
118+ f"(got { guidance_scale_low } , { guidance_scale_high } ). "
119+ "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
120+ )
121+
113122 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
114123 prompt ,
115124 negative_prompt ,
@@ -138,6 +147,8 @@ def __call__(
138147 num_inference_steps = num_inference_steps ,
139148 scheduler = self .scheduler ,
140149 scheduler_state = scheduler_state ,
150+ use_cfg_cache = use_cfg_cache ,
151+ height = height ,
141152 )
142153
143154 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -172,51 +183,184 @@ def run_inference_2_2(
172183 num_inference_steps : int ,
173184 scheduler : FlaxUniPCMultistepScheduler ,
174185 scheduler_state ,
186+ use_cfg_cache : bool = False ,
187+ height : int = 480 ,
175188):
189+ """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190+
191+ Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192+ - High-noise phase (t >= boundary): always full CFG — short phase, critical
193+ for establishing video structure.
194+ - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195+ steps, FFT frequency-domain compensation on cache steps (batch×1).
196+ - Boundary transition: mandatory full CFG step to populate cache for the
197+ low-noise transformer.
198+ - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
199+ """
176200 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
177- if do_classifier_free_guidance :
178- prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
179-
180- def low_noise_branch (operands ):
181- latents , timestep , prompt_embeds = operands
182- return transformer_forward_pass (
183- low_noise_graphdef ,
184- low_noise_state ,
185- low_noise_rest ,
186- latents ,
187- timestep ,
188- prompt_embeds ,
189- do_classifier_free_guidance ,
190- guidance_scale_low ,
191- )
201+ bsz = latents .shape [0 ]
192202
193- def high_noise_branch (operands ):
194- latents , timestep , prompt_embeds = operands
195- return transformer_forward_pass (
196- high_noise_graphdef ,
197- high_noise_state ,
198- high_noise_rest ,
199- latents ,
200- timestep ,
201- prompt_embeds ,
202- do_classifier_free_guidance ,
203- guidance_scale_high ,
203+ # ── CFG cache path ──
204+ if use_cfg_cache and do_classifier_free_guidance :
205+ # Get timesteps as numpy for Python-level scheduling decisions
206+ timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
207+ step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
208+
209+ # Resolution-dependent CFG cache config — adapted for Wan 2.2.
210+ # (Currently only cfg_cache_end_step differs between the >=720p and
210+ # lower-resolution branches; interval/start/alpha are identical.)
211+ #
211+ # Key differences from Wan 2.1 (50 steps, single transformer):
212+ # 1. Fewer steps (30) → each step covers more denoising, so cache
213+ # drifts faster. NOTE(review): this rationale was written for interval=3
214+ # (max staleness 2 steps, ~7% of total) but the code below sets interval=5
214+ # (max staleness 4 steps, ~13%) — reconcile the comment or the constant.
215+ # 2. Low-noise transformer specialises in detail refinement, so
216+ # cond–uncond differences are more volatile. NOTE(review): text here
217+ # calls for a lower α (0.1) to avoid overshooting the FFT correction,
217+ # but the code below sets cfg_cache_alpha = 0.2 — reconcile.
218+ # 3. Phase weights: the boundary already encodes the structural→detail
219+ # transition. All low-noise (cache) steps use high-freq emphasis.
220+ if height >= 720 :
221+ cfg_cache_interval = 5
222+ cfg_cache_start_step = int (num_inference_steps / 3 )
223+ cfg_cache_end_step = int (num_inference_steps * 0.9 )
224+ cfg_cache_alpha = 0.2
225+ else :
226+ cfg_cache_interval = 5
227+ cfg_cache_start_step = int (num_inference_steps / 3 )
228+ cfg_cache_end_step = num_inference_steps - 1
229+ cfg_cache_alpha = 0.2
230+
231+ # Pre-split embeds once
232+ prompt_cond_embeds = prompt_embeds
233+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
234+
235+ # Determine the first low-noise step (boundary transition).
236+ # In Wan 2.2 the boundary IS the structural→detail transition, so
237+ # all low-noise cache steps should emphasise high-frequency correction.
238+ first_low_step = next (
239+ (s for s in range (num_inference_steps ) if not step_uses_high [s ]),
240+ num_inference_steps ,
204241 )
242+ t0_step = first_low_step # all cache steps get high-freq boost
243+
244+ # Pre-compute cache schedule and phase-dependent weights.
245+ first_full_in_low_seen = False
246+ step_is_cache = []
247+ step_w1w2 = []
248+ for s in range (num_inference_steps ):
249+ if step_uses_high [s ]:
250+ # Never cache high-noise transformer steps
251+ step_is_cache .append (False )
252+ else :
253+ is_cache = (
254+ first_full_in_low_seen
255+ and s >= cfg_cache_start_step
256+ and s < cfg_cache_end_step
257+ and (s - cfg_cache_start_step ) % cfg_cache_interval != 0
258+ )
259+ step_is_cache .append (is_cache )
260+ if not is_cache :
261+ first_full_in_low_seen = True
262+
263+ # Phase-dependent weights: w = 1 + α·I(condition)
264+ if s < t0_step :
265+ step_w1w2 .append ((1.0 + cfg_cache_alpha , 1.0 )) # high-noise: boost low-freq
266+ else :
267+ step_w1w2 .append ((1.0 , 1.0 + cfg_cache_alpha )) # low-noise: boost high-freq
268+
269+ # Cache tensors (on-device JAX arrays, initialised to None).
270+ cached_noise_cond = None
271+ cached_noise_uncond = None
272+
273+ for step in range (num_inference_steps ):
274+ t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
275+ is_cache_step = step_is_cache [step ]
276+
277+ # Select transformer and guidance scale based on precomputed schedule
278+ if step_uses_high [step ]:
279+ graphdef , state , rest = high_noise_graphdef , high_noise_state , high_noise_rest
280+ guidance_scale = guidance_scale_high
281+ else :
282+ graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
283+ guidance_scale = guidance_scale_low
284+
285+ if is_cache_step :
286+ # ── Cache step: cond-only forward + FFT frequency compensation ──
287+ w1 , w2 = step_w1w2 [step ]
288+ timestep = jnp .broadcast_to (t , bsz )
289+ noise_pred , cached_noise_cond = transformer_forward_pass_cfg_cache (
290+ graphdef ,
291+ state ,
292+ rest ,
293+ latents ,
294+ timestep ,
295+ prompt_cond_embeds ,
296+ cached_noise_cond ,
297+ cached_noise_uncond ,
298+ guidance_scale = guidance_scale ,
299+ w1 = jnp .float32 (w1 ),
300+ w2 = jnp .float32 (w2 ),
301+ )
302+ else :
303+ # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
304+ latents_doubled = jnp .concatenate ([latents ] * 2 )
305+ timestep = jnp .broadcast_to (t , bsz * 2 )
306+ noise_pred , cached_noise_cond , cached_noise_uncond = transformer_forward_pass_full_cfg (
307+ graphdef ,
308+ state ,
309+ rest ,
310+ latents_doubled ,
311+ timestep ,
312+ prompt_embeds_combined ,
313+ guidance_scale = guidance_scale ,
314+ )
315+
316+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
317+ return latents
318+
319+ # ── Original non-cache path ──
320+ # Uses same Python-level if/else transformer selection as the cache path
321+ # so both paths compile to identical XLA graphs (critical for bfloat16
322+ # reproducibility in the PSNR comparison).
323+ timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
324+ step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
325+
326+ prompt_embeds_combined = (
327+ jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 ) if do_classifier_free_guidance else prompt_embeds
328+ )
205329
206330 for step in range (num_inference_steps ):
207331 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
208- if do_classifier_free_guidance :
209- latents = jnp .concatenate ([latents ] * 2 )
210- timestep = jnp .broadcast_to (t , latents .shape [0 ])
211332
212- use_high_noise = jnp .greater_equal (t , boundary )
333+ if step_uses_high [step ]:
334+ graphdef , state , rest = high_noise_graphdef , high_noise_state , high_noise_rest
335+ guidance_scale = guidance_scale_high
336+ else :
337+ graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
338+ guidance_scale = guidance_scale_low
213339
214- # Selects the model based on the current timestep:
215- # - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise).
216- # - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise).
217- noise_pred , latents = jax .lax .cond (
218- use_high_noise , high_noise_branch , low_noise_branch , (latents , timestep , prompt_embeds )
219- )
340+ if do_classifier_free_guidance :
341+ latents_doubled = jnp .concatenate ([latents ] * 2 )
342+ timestep = jnp .broadcast_to (t , bsz * 2 )
343+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
344+ graphdef ,
345+ state ,
346+ rest ,
347+ latents_doubled ,
348+ timestep ,
349+ prompt_embeds_combined ,
350+ guidance_scale = guidance_scale ,
351+ )
352+ else :
353+ timestep = jnp .broadcast_to (t , bsz )
354+ noise_pred , latents = transformer_forward_pass (
355+ graphdef ,
356+ state ,
357+ rest ,
358+ latents ,
359+ timestep ,
360+ prompt_embeds ,
361+ do_classifier_free_guidance ,
362+ guidance_scale ,
363+ )
220364
221365 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
222366 return latents
0 commit comments