Skip to content

Commit 63ffeb4

Browse files
author
James Huang
committed
cfg cache
Signed-off-by: James Huang <shyhuang@google.com>
1 parent cddbf6a commit 63ffeb4

4 files changed

Lines changed: 234 additions & 18 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ num_frames: 81
324324
guidance_scale: 5.0
325325
flow_shift: 3.0
326326

327+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
328+
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
329+
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
330+
use_cfg_cache: False
331+
327332
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
328333
guidance_rescale: 0.0
329334
num_inference_steps: 30

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
125125
num_frames=config.num_frames,
126126
num_inference_steps=config.num_inference_steps,
127127
guidance_scale=config.guidance_scale,
128+
use_cfg_cache=config.use_cfg_cache,
128129
)
129130
elif model_key == WAN2_2:
130131
return pipeline(

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,108 @@ def transformer_forward_pass(
778778
latents = latents[:bsz]
779779

780780
return noise_pred, latents
781+
782+
783+
@partial(jax.jit, static_argnames=("guidance_scale",))
def transformer_forward_pass_full_cfg(
    graphdef,
    sharded_state,
    rest_of_state,
    latents_doubled: jax.Array,
    timestep: jax.Array,
    prompt_embeds_combined: jax.Array,
    guidance_scale: float,
    encoder_hidden_states_image=None,
):
  """Full classifier-free-guidance (CFG) forward pass.

  Accepts pre-doubled latents and pre-concatenated [cond, uncond] prompt
  embeds, so the first half of the batch is the conditional branch and the
  second half the unconditional one.

  Args:
    graphdef: nnx graph definition of the WAN transformer.
    sharded_state: sharded module state, merged back via ``nnx.merge``.
    rest_of_state: remaining (non-sharded) module state.
    latents_doubled: latent batch of size ``2 * bsz`` ([cond; uncond]).
    timestep: per-sample timestep array, broadcast to ``2 * bsz``.
    prompt_embeds_combined: concatenated [cond, uncond] text embeddings.
    guidance_scale: CFG scale; static so XLA specializes the merged formula.
    encoder_hidden_states_image: optional image conditioning embeddings.

  Returns:
    Tuple ``(noise_pred_merged, noise_cond, noise_uncond)``. The raw
    cond/uncond predictions are returned for CFG-cache storage; keeping them
    separate avoids a second forward pass on cache steps.
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  # Half the doubled batch — split point between cond and uncond outputs.
  bsz = latents_doubled.shape[0] // 2
  noise_pred = wan_transformer(
      hidden_states=latents_doubled,
      timestep=timestep,
      encoder_hidden_states=prompt_embeds_combined,
      encoder_hidden_states_image=encoder_hidden_states_image,
  )
  noise_cond = noise_pred[:bsz]
  noise_uncond = noise_pred[bsz:]
  # Standard CFG merge: uncond + s * (cond - uncond).
  noise_pred_merged = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
  return noise_pred_merged, noise_cond, noise_uncond
813+
814+
815+
@partial(jax.jit, static_argnames=("guidance_scale",))
def transformer_forward_pass_cfg_cache(
    graphdef,
    sharded_state,
    rest_of_state,
    latents_cond: jax.Array,
    timestep_cond: jax.Array,
    prompt_cond_embeds: jax.Array,
    cached_noise_cond: jax.Array,
    cached_noise_uncond: jax.Array,
    guidance_scale: float,
    w1: float = 1.0,
    w2: float = 1.0,
    encoder_hidden_states_image=None,
):
  """CFG-Cache forward pass with FFT frequency-domain compensation.

  FasterCache (Lv et al., ICLR 2025) CFG-Cache:
    1. Compute frequency-domain bias: ΔF = FFT(uncond) - FFT(cond)
    2. Split into low-freq (ΔLF) and high-freq (ΔHF) via spectral mask
    3. Apply phase-dependent weights:
         F_low  = FFT(new_cond)_low  + w1 * ΔLF
         F_high = FFT(new_cond)_high + w2 * ΔHF
    4. Reconstruct: uncond_approx = IFFT(F_low + F_high)

  w1/w2 encode the denoising phase (α = 0.2, FasterCache default):
    Early (high noise): w1 = 1 + α, w2 = 1  → boost low-freq correction
    Late  (low noise):  w1 = 1,     w2 = 1 + α → boost high-freq correction
  Pass w1/w2 as traced scalars (e.g. ``jnp.float32``) so changing them per
  step does not trigger recompilation.

  Args:
    graphdef: nnx graph definition of the WAN transformer.
    sharded_state: sharded module state, merged back via ``nnx.merge``.
    rest_of_state: remaining (non-sharded) module state.
    latents_cond: conditional-branch latents only (half a full-CFG batch).
    timestep_cond: timestep array broadcast to the cond batch size.
    prompt_cond_embeds: conditional text embeddings only.
    cached_noise_cond: raw cond prediction from the last full-CFG step.
    cached_noise_uncond: raw uncond prediction from the last full-CFG step.
    guidance_scale: CFG scale; static argument.
    w1: weight applied to the low-frequency bias band.
    w2: weight applied to the high-frequency bias band.
    encoder_hidden_states_image: optional image conditioning embeddings.

  Returns:
    Tuple ``(noise_pred_merged, noise_cond)`` — the CFG-merged prediction and
    the fresh cond prediction (stored by the caller as the new cache entry).

  On TPU this compiles to a single static XLA graph with half the batch size
  of a full CFG pass.
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  noise_cond = wan_transformer(
      hidden_states=latents_cond,
      timestep=timestep_cond,
      encoder_hidden_states=prompt_cond_embeds,
      encoder_hidden_states_image=encoder_hidden_states_image,
  )

  # FFT over spatial dims (H, W) — rfft2 operates on the last 2 dims;
  # assumes latents are laid out [..., H, W] (e.g. [B, C, F, H, W]).
  # float32 keeps the FFT numerically stable regardless of compute dtype.
  fft_cond_cached = jnp.fft.rfft2(cached_noise_cond.astype(jnp.float32))
  fft_uncond_cached = jnp.fft.rfft2(cached_noise_uncond.astype(jnp.float32))
  fft_bias = fft_uncond_cached - fft_cond_cached

  # Build low/high frequency mask (25% cutoff). Shapes are static Python
  # ints, so use plain max() — jnp.maximum would wrap them in traced scalars
  # for no benefit.
  h = fft_bias.shape[-2]
  w_rfft = fft_bias.shape[-1]
  ch = max(1, h // 4)
  cw = max(1, w_rfft // 4)
  freq_h = jnp.arange(h)
  freq_w = jnp.arange(w_rfft)
  # Low-freq: indices near DC (0) in both dims. Dim H is a full FFT, so its
  # negative frequencies wrap to the top of the index range; dim W is the
  # half-spectrum of rfft and needs no wrap handling.
  low_h = (freq_h < ch) | (freq_h >= h - ch + 1)
  low_w = freq_w < cw
  low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32)
  high_mask = 1.0 - low_mask

  # Apply phase-dependent weights to the frequency bias.
  fft_bias_weighted = fft_bias * (low_mask * w1 + high_mask * w2)

  # Reconstruct the unconditional output from the fresh cond prediction plus
  # the weighted cached bias; s= pins the inverse transform to the original
  # spatial shape (irfft2 cannot infer an odd-sized last dim).
  fft_cond_new = jnp.fft.rfft2(noise_cond.astype(jnp.float32))
  fft_uncond_approx = fft_cond_new + fft_bias_weighted
  noise_uncond_approx = jnp.fft.irfft2(
      fft_uncond_approx, s=noise_cond.shape[-2:]
  ).astype(noise_cond.dtype)

  # Standard CFG merge using the approximated uncond branch.
  noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx)
  return noise_pred_merged, noise_cond

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 123 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
@@ -90,6 +90,7 @@ def __call__(
9090
prompt_embeds: Optional[jax.Array] = None,
9191
negative_prompt_embeds: Optional[jax.Array] = None,
9292
vae_only: bool = False,
93+
use_cfg_cache: bool = False,
9394
):
9495
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
9596
prompt,
@@ -114,6 +115,8 @@ def __call__(
114115
num_inference_steps=num_inference_steps,
115116
scheduler=self.scheduler,
116117
scheduler_state=scheduler_state,
118+
use_cfg_cache=use_cfg_cache,
119+
height=height,
117120
)
118121

119122
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -140,26 +143,128 @@ def run_inference_2_1(
140143
num_inference_steps: int,
141144
scheduler: FlaxUniPCMultistepScheduler,
142145
scheduler_state,
146+
use_cfg_cache: bool = False,
147+
height: int = 480,
143148
):
144-
do_classifier_free_guidance = guidance_scale > 1.0
145-
if do_classifier_free_guidance:
146-
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
149+
"""Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
150+
151+
CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
152+
- Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
153+
Cache raw noise_cond and noise_uncond for FFT bias.
154+
- Cache steps : run transformer on cond batch only (batch×1).
155+
Estimate uncond via FFT frequency-domain compensation:
156+
ΔF = FFT(cached_uncond) - FFT(cached_cond)
157+
Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
158+
uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
159+
Phase-dependent weights (α=0.2):
160+
Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
161+
Late (low noise): w1=1.0, w2=1.2 (boost high-freq)
162+
- Schedule : full CFG for the first 1/3 of steps, then
163+
full CFG every 5 steps, cache the rest.
164+
165+
Two separately-compiled JAX-jitted functions handle full and cache steps so
166+
XLA sees static shapes throughout — the key requirement for TPU efficiency.
167+
"""
168+
do_cfg = guidance_scale > 1.0
169+
bsz = latents.shape[0]
170+
171+
# Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
172+
if height >= 720:
173+
# 720p: conservative — protect last 10% of steps, interval=5
174+
cfg_cache_interval = 5
175+
cfg_cache_start_step = int(num_inference_steps / 3)
176+
cfg_cache_end_step = int(num_inference_steps * 0.9)
177+
cfg_cache_alpha = 0.2
178+
else:
179+
# 480p: moderate — protect last 2 steps, interval=5
180+
cfg_cache_interval = 5
181+
cfg_cache_start_step = int(num_inference_steps / 3)
182+
cfg_cache_end_step = num_inference_steps - 2
183+
cfg_cache_alpha = 0.2
184+
185+
# Pre-split embeds once, outside the loop.
186+
prompt_cond_embeds = prompt_embeds
187+
prompt_embeds_combined = None
188+
if do_cfg:
189+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
190+
191+
# Pre-compute cache schedule and phase-dependent weights.
192+
# t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
193+
t0_step = num_inference_steps // 2
194+
first_full_step_seen = False
195+
step_is_cache = []
196+
step_w1w2 = []
197+
for s in range(num_inference_steps):
198+
is_cache = (
199+
use_cfg_cache
200+
and do_cfg
201+
and first_full_step_seen
202+
and s >= cfg_cache_start_step
203+
and s < cfg_cache_end_step
204+
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
205+
)
206+
step_is_cache.append(is_cache)
207+
if not is_cache:
208+
first_full_step_seen = True
209+
# Phase-dependent weights: w = 1 + α·I(condition)
210+
if s < t0_step:
211+
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # early: boost low-freq
212+
else:
213+
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # late: boost high-freq
214+
215+
# Cache tensors (on-device JAX arrays, initialised to None).
216+
cached_noise_cond = None
217+
cached_noise_uncond = None
218+
147219
for step in range(num_inference_steps):
148220
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
149-
if do_classifier_free_guidance:
150-
latents = jnp.concatenate([latents] * 2)
151-
timestep = jnp.broadcast_to(t, latents.shape[0])
152-
153-
noise_pred, latents = transformer_forward_pass(
154-
graphdef,
155-
sharded_state,
156-
rest_of_state,
157-
latents,
158-
timestep,
159-
prompt_embeds,
160-
do_classifier_free_guidance=do_classifier_free_guidance,
161-
guidance_scale=guidance_scale,
162-
)
221+
is_cache_step = step_is_cache[step]
222+
223+
if is_cache_step:
224+
# ── Cache step: cond-only forward + FFT frequency compensation ──
225+
w1, w2 = step_w1w2[step]
226+
timestep = jnp.broadcast_to(t, bsz)
227+
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
228+
graphdef,
229+
sharded_state,
230+
rest_of_state,
231+
latents,
232+
timestep,
233+
prompt_cond_embeds,
234+
cached_noise_cond,
235+
cached_noise_uncond,
236+
guidance_scale=guidance_scale,
237+
w1=jnp.float32(w1),
238+
w2=jnp.float32(w2),
239+
)
240+
241+
elif do_cfg:
242+
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
243+
latents_doubled = jnp.concatenate([latents] * 2)
244+
timestep = jnp.broadcast_to(t, bsz * 2)
245+
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
246+
graphdef,
247+
sharded_state,
248+
rest_of_state,
249+
latents_doubled,
250+
timestep,
251+
prompt_embeds_combined,
252+
guidance_scale=guidance_scale,
253+
)
254+
255+
else:
256+
# ── No CFG (guidance_scale <= 1.0) ──
257+
timestep = jnp.broadcast_to(t, bsz)
258+
noise_pred, latents = transformer_forward_pass(
259+
graphdef,
260+
sharded_state,
261+
rest_of_state,
262+
latents,
263+
timestep,
264+
prompt_cond_embeds,
265+
do_classifier_free_guidance=False,
266+
guidance_scale=guidance_scale,
267+
)
163268

164269
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
165270
return latents

0 commit comments

Comments
 (0)