Skip to content

Commit a4bf967

Browse files
author
James Huang
committed
CFG Cache For Wan 2.2
Signed-off-by: James Huang <shyhuang@google.com>
1 parent 4085595 commit a4bf967

2 files changed

Lines changed: 469 additions & 38 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 182 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
@@ -21,6 +21,7 @@
2121
from flax.linen import partitioning as nn_partitioning
2222
import jax
2323
import jax.numpy as jnp
24+
import numpy as np
2425
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2526

2627

@@ -32,7 +33,7 @@ def __init__(
3233
config: HyperParameters,
3334
low_noise_transformer: Optional[WanModel],
3435
high_noise_transformer: Optional[WanModel],
35-
**kwargs
36+
**kwargs,
3637
):
3738
super().__init__(config=config, **kwargs)
3839
self.low_noise_transformer = low_noise_transformer
@@ -109,7 +110,15 @@ def __call__(
109110
prompt_embeds: jax.Array = None,
110111
negative_prompt_embeds: jax.Array = None,
111112
vae_only: bool = False,
113+
use_cfg_cache: bool = False,
112114
):
115+
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
116+
raise ValueError(
117+
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
118+
f"(got {guidance_scale_low}, {guidance_scale_high}). "
119+
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
120+
)
121+
113122
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
114123
prompt,
115124
negative_prompt,
@@ -138,6 +147,8 @@ def __call__(
138147
num_inference_steps=num_inference_steps,
139148
scheduler=self.scheduler,
140149
scheduler_state=scheduler_state,
150+
use_cfg_cache=use_cfg_cache,
151+
height=height,
141152
)
142153

143154
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -172,51 +183,184 @@ def run_inference_2_2(
172183
num_inference_steps: int,
173184
scheduler: FlaxUniPCMultistepScheduler,
174185
scheduler_state,
186+
use_cfg_cache: bool = False,
187+
height: int = 480,
175188
):
189+
"""Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190+
191+
Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192+
- High-noise phase (t >= boundary): always full CFG — short phase, critical
193+
for establishing video structure.
194+
- Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195+
steps, FFT frequency-domain compensation on cache steps (batch×1).
196+
- Boundary transition: mandatory full CFG step to populate cache for the
197+
low-noise transformer.
198+
- FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
199+
"""
176200
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
177-
if do_classifier_free_guidance:
178-
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
179-
180-
def low_noise_branch(operands):
181-
latents, timestep, prompt_embeds = operands
182-
return transformer_forward_pass(
183-
low_noise_graphdef,
184-
low_noise_state,
185-
low_noise_rest,
186-
latents,
187-
timestep,
188-
prompt_embeds,
189-
do_classifier_free_guidance,
190-
guidance_scale_low,
191-
)
201+
bsz = latents.shape[0]
192202

193-
def high_noise_branch(operands):
194-
latents, timestep, prompt_embeds = operands
195-
return transformer_forward_pass(
196-
high_noise_graphdef,
197-
high_noise_state,
198-
high_noise_rest,
199-
latents,
200-
timestep,
201-
prompt_embeds,
202-
do_classifier_free_guidance,
203-
guidance_scale_high,
203+
# ── CFG cache path ──
204+
if use_cfg_cache and do_classifier_free_guidance:
205+
# Get timesteps as numpy for Python-level scheduling decisions
206+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
207+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
208+
209+
# Resolution-dependent CFG cache config — adapted for Wan 2.2.
210+
#
211+
# Key differences from Wan 2.1 (50 steps, single transformer):
212+
# 1. Fewer steps (30) → each step covers more denoising, so cache
213+
#   drifts faster. NOTE(review): this rationale derives interval=3
214+
#   (max staleness 2 steps, ~7% of 30 — matching Wan 2.1's 4/50 ratio
#   at interval=5), yet the code below sets cfg_cache_interval = 5; reconcile.
215+
# 2. Low-noise transformer specialises in detail refinement, so
216+
# cond–uncond differences are more volatile. Lower α (0.1)
217+
# avoids overshooting the FFT correction.
218+
# 3. Phase weights: the boundary already encodes the structural→detail
219+
# transition. All low-noise (cache) steps use high-freq emphasis.
220+
if height >= 720:
221+
cfg_cache_interval = 5
222+
cfg_cache_start_step = int(num_inference_steps / 3)
223+
cfg_cache_end_step = int(num_inference_steps * 0.9)
224+
cfg_cache_alpha = 0.2
225+
else:
226+
cfg_cache_interval = 5
227+
cfg_cache_start_step = int(num_inference_steps / 3)
228+
cfg_cache_end_step = num_inference_steps - 1
229+
cfg_cache_alpha = 0.2
230+
231+
# Pre-split embeds once
232+
prompt_cond_embeds = prompt_embeds
233+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
234+
235+
# Determine the first low-noise step (boundary transition).
236+
# In Wan 2.2 the boundary IS the structural→detail transition, so
237+
# all low-noise cache steps should emphasise high-frequency correction.
238+
first_low_step = next(
239+
(s for s in range(num_inference_steps) if not step_uses_high[s]),
240+
num_inference_steps,
204241
)
242+
t0_step = first_low_step # all cache steps get high-freq boost
243+
244+
# Pre-compute cache schedule and phase-dependent weights.
245+
first_full_in_low_seen = False
246+
step_is_cache = []
247+
step_w1w2 = []
248+
for s in range(num_inference_steps):
249+
if step_uses_high[s]:
250+
# Never cache high-noise transformer steps
251+
step_is_cache.append(False)
252+
else:
253+
is_cache = (
254+
first_full_in_low_seen
255+
and s >= cfg_cache_start_step
256+
and s < cfg_cache_end_step
257+
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
258+
)
259+
step_is_cache.append(is_cache)
260+
if not is_cache:
261+
first_full_in_low_seen = True
262+
263+
# Phase-dependent weights: w = 1 + α·I(condition)
264+
if s < t0_step:
265+
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq
266+
else:
267+
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq
268+
269+
# Cond/uncond prediction caches: start as None, and hold on-device JAX arrays once the first full-CFG step populates them.
270+
cached_noise_cond = None
271+
cached_noise_uncond = None
272+
273+
for step in range(num_inference_steps):
274+
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
275+
is_cache_step = step_is_cache[step]
276+
277+
# Select transformer and guidance scale based on precomputed schedule
278+
if step_uses_high[step]:
279+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
280+
guidance_scale = guidance_scale_high
281+
else:
282+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
283+
guidance_scale = guidance_scale_low
284+
285+
if is_cache_step:
286+
# ── Cache step: cond-only forward + FFT frequency compensation ──
287+
w1, w2 = step_w1w2[step]
288+
timestep = jnp.broadcast_to(t, bsz)
289+
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
290+
graphdef,
291+
state,
292+
rest,
293+
latents,
294+
timestep,
295+
prompt_cond_embeds,
296+
cached_noise_cond,
297+
cached_noise_uncond,
298+
guidance_scale=guidance_scale,
299+
w1=jnp.float32(w1),
300+
w2=jnp.float32(w2),
301+
)
302+
else:
303+
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
304+
latents_doubled = jnp.concatenate([latents] * 2)
305+
timestep = jnp.broadcast_to(t, bsz * 2)
306+
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
307+
graphdef,
308+
state,
309+
rest,
310+
latents_doubled,
311+
timestep,
312+
prompt_embeds_combined,
313+
guidance_scale=guidance_scale,
314+
)
315+
316+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
317+
return latents
318+
319+
# ── Original non-cache path ──
320+
# Uses same Python-level if/else transformer selection as the cache path
321+
# so both paths compile to identical XLA graphs (critical for bfloat16
322+
# reproducibility in the PSNR comparison).
323+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
324+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
325+
326+
prompt_embeds_combined = (
327+
jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
328+
)
205329

206330
for step in range(num_inference_steps):
207331
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
208-
if do_classifier_free_guidance:
209-
latents = jnp.concatenate([latents] * 2)
210-
timestep = jnp.broadcast_to(t, latents.shape[0])
211332

212-
use_high_noise = jnp.greater_equal(t, boundary)
333+
if step_uses_high[step]:
334+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
335+
guidance_scale = guidance_scale_high
336+
else:
337+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
338+
guidance_scale = guidance_scale_low
213339

214-
# Selects the model based on the current timestep:
215-
# - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise).
216-
# - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise).
217-
noise_pred, latents = jax.lax.cond(
218-
use_high_noise, high_noise_branch, low_noise_branch, (latents, timestep, prompt_embeds)
219-
)
340+
if do_classifier_free_guidance:
341+
latents_doubled = jnp.concatenate([latents] * 2)
342+
timestep = jnp.broadcast_to(t, bsz * 2)
343+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
344+
graphdef,
345+
state,
346+
rest,
347+
latents_doubled,
348+
timestep,
349+
prompt_embeds_combined,
350+
guidance_scale=guidance_scale,
351+
)
352+
else:
353+
timestep = jnp.broadcast_to(t, bsz)
354+
noise_pred, latents = transformer_forward_pass(
355+
graphdef,
356+
state,
357+
rest,
358+
latents,
359+
timestep,
360+
prompt_embeds,
361+
do_classifier_free_guidance,
362+
guidance_scale,
363+
)
220364

221365
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
222366
return latents

0 commit comments

Comments (0)