
Commit e04e78d

James Huang authored and eltsai committed
cfg cache
Signed-off-by: James Huang <shyhuang@google.com>
1 parent 115fffa commit e04e78d

4 files changed: 234 additions & 18 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 0 deletions
@@ -323,6 +323,11 @@ num_frames: 81
 guidance_scale: 5.0
 flow_shift: 3.0

+# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
+# Skips the unconditional forward pass on ~45% of steps at default settings via residual compensation.
+# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
+use_cfg_cache: False
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
         num_frames=config.num_frames,
         num_inference_steps=config.num_inference_steps,
         guidance_scale=config.guidance_scale,
+        use_cfg_cache=config.use_cfg_cache,
     )
   elif model_key == WAN2_2:
     return pipeline(
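
For orientation, a sketch of the resulting call with the cache enabled. Only num_frames, num_inference_steps, guidance_scale, and use_cfg_cache are confirmed by this diff; pipeline construction and the remaining keyword arguments are illustrative assumptions:

# Hypothetical call-site sketch. Everything except num_frames,
# num_inference_steps, guidance_scale, and use_cfg_cache is assumed.
video = pipeline(
    prompt="a red panda rafting down a jungle river",
    negative_prompt="blurry, low quality, distorted",
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    use_cfg_cache=True,  # enables the FasterCache-style CFG cache added here
)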

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 105 additions & 0 deletions
@@ -812,3 +812,108 @@ def transformer_forward_pass(
     latents = latents[:bsz]

   return noise_pred, latents
+
+
+@partial(jax.jit, static_argnames=("guidance_scale",))
+def transformer_forward_pass_full_cfg(
+    graphdef,
+    sharded_state,
+    rest_of_state,
+    latents_doubled: jnp.array,
+    timestep: jnp.array,
+    prompt_embeds_combined: jnp.array,
+    guidance_scale: float,
+    encoder_hidden_states_image=None,
+):
+  """Full CFG forward pass.
+
+  Accepts pre-doubled latents and pre-concatenated [cond, uncond] prompt embeds.
+  Returns the merged noise_pred plus the raw noise_cond and noise_uncond for
+  CFG-cache storage. Keeping cond/uncond separate avoids a second forward
+  pass on cache steps.
+  """
+  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
+  bsz = latents_doubled.shape[0] // 2
+  noise_pred = wan_transformer(
+      hidden_states=latents_doubled,
+      timestep=timestep,
+      encoder_hidden_states=prompt_embeds_combined,
+      encoder_hidden_states_image=encoder_hidden_states_image,
+  )
+  noise_cond = noise_pred[:bsz]
+  noise_uncond = noise_pred[bsz:]
+  noise_pred_merged = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
+  return noise_pred_merged, noise_cond, noise_uncond
+
+
+@partial(jax.jit, static_argnames=("guidance_scale",))
+def transformer_forward_pass_cfg_cache(
+    graphdef,
+    sharded_state,
+    rest_of_state,
+    latents_cond: jnp.array,
+    timestep_cond: jnp.array,
+    prompt_cond_embeds: jnp.array,
+    cached_noise_cond: jnp.array,
+    cached_noise_uncond: jnp.array,
+    guidance_scale: float,
+    w1: float = 1.0,
+    w2: float = 1.0,
+    encoder_hidden_states_image=None,
+):
+  """CFG-Cache forward pass with FFT frequency-domain compensation.
+
+  FasterCache (Lv et al., ICLR 2025) CFG-Cache:
+    1. Compute the frequency-domain bias: ΔF = FFT(uncond) - FFT(cond)
+    2. Split it into low-freq (ΔLF) and high-freq (ΔHF) parts via a spectral mask
+    3. Apply phase-dependent weights:
+         F_low  = FFT(new_cond)_low  + w1 * ΔLF
+         F_high = FFT(new_cond)_high + w2 * ΔHF
+    4. Reconstruct: uncond_approx = IFFT(F_low + F_high)
+
+  w1/w2 encode the denoising phase:
+    Early (high noise): w1=1+α, w2=1 → boost the low-freq correction
+    Late (low noise):   w1=1, w2=1+α → boost the high-freq correction
+  where α=0.2 (the FasterCache default).
+
+  On TPU this compiles to a single static XLA graph with half the batch size
+  of a full CFG pass.
+  """
+  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
+  noise_cond = wan_transformer(
+      hidden_states=latents_cond,
+      timestep=timestep_cond,
+      encoder_hidden_states=prompt_cond_embeds,
+      encoder_hidden_states_image=encoder_hidden_states_image,
+  )
+
+  # FFT over the spatial dims (H, W) — the last 2 dims of [B, C, F, H, W].
+  fft_cond_cached = jnp.fft.rfft2(cached_noise_cond.astype(jnp.float32))
+  fft_uncond_cached = jnp.fft.rfft2(cached_noise_uncond.astype(jnp.float32))
+  fft_bias = fft_uncond_cached - fft_cond_cached
+
+  # Build the low/high frequency mask (25% cutoff).
+  h = fft_bias.shape[-2]
+  w_rfft = fft_bias.shape[-1]
+  ch = jnp.maximum(1, h // 4)
+  cw = jnp.maximum(1, w_rfft // 4)
+  freq_h = jnp.arange(h)
+  freq_w = jnp.arange(w_rfft)
+  # Low-freq: indices near DC (0) in both dims; account for wrap-around in dim H.
+  low_h = (freq_h < ch) | (freq_h >= h - ch + 1)
+  low_w = freq_w < cw
+  low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32)
+  high_mask = 1.0 - low_mask
+
+  # Apply the phase-dependent weights to the frequency bias.
+  fft_bias_weighted = fft_bias * (low_mask * w1 + high_mask * w2)
+
+  # Reconstruct the approximate unconditional output.
+  fft_cond_new = jnp.fft.rfft2(noise_cond.astype(jnp.float32))
+  fft_uncond_approx = fft_cond_new + fft_bias_weighted
+  noise_uncond_approx = jnp.fft.irfft2(
+      fft_uncond_approx, s=noise_cond.shape[-2:]
+  ).astype(noise_cond.dtype)
+
+  noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx)
+  return noise_pred_merged, noise_cond
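
The compensation path is easy to exercise in isolation. A minimal standalone sketch on toy tensors (shapes and values are invented; the mask construction and reconstruction mirror transformer_forward_pass_cfg_cache above):

import jax.numpy as jnp

# Toy stand-ins for cached cond/uncond predictions, shape [B, C, F, H, W].
shape = (1, 2, 3, 8, 8)
cached_cond = jnp.full(shape, 0.5)
cached_uncond = jnp.full(shape, 0.7)
new_cond = jnp.full(shape, 0.6)

# Frequency-domain bias between the cached branches (over the last two dims).
fft_bias = jnp.fft.rfft2(cached_uncond) - jnp.fft.rfft2(cached_cond)

# Same 25%-cutoff low/high split as the pipeline code.
h, w_rfft = fft_bias.shape[-2], fft_bias.shape[-1]
ch, cw = max(1, h // 4), max(1, w_rfft // 4)
low_h = (jnp.arange(h) < ch) | (jnp.arange(h) >= h - ch + 1)
low_w = jnp.arange(w_rfft) < cw
low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32)

w1, w2 = 1.2, 1.0  # early-phase weights (alpha = 0.2)
weighted_bias = fft_bias * (low_mask * w1 + (1.0 - low_mask) * w2)
uncond_approx = jnp.fft.irfft2(jnp.fft.rfft2(new_cond) + weighted_bias, s=shape[-2:])

# Same CFG merge as the full pass, using the approximated uncond branch.
noise_pred = uncond_approx + 5.0 * (new_cond - uncond_approx)

With w1 = w2 = 1 the construction collapses, by linearity of the FFT, to plain residual compensation, new_cond + (cached_uncond - cached_cond); the weights only rescale the low- and high-frequency bands. For these constant toy tensors the bias sits entirely in the DC bin, which the mask classifies as low frequency, so uncond_approx comes out uniformly 0.6 + 1.2 * 0.2 = 0.84.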

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 123 additions & 18 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .wan_pipeline import WanPipeline, transformer_forward_pass
+from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
 from ...models.wan.transformers.transformer_wan import WanModel
 from typing import List, Union, Optional
 from ...pyconfig import HyperParameters
@@ -91,6 +91,7 @@ def __call__(
     prompt_embeds: Optional[jax.Array] = None,
     negative_prompt_embeds: Optional[jax.Array] = None,
     vae_only: bool = False,
+    use_cfg_cache: bool = False,
 ):
   latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
       prompt,
@@ -115,6 +116,8 @@ def __call__(
       num_inference_steps=num_inference_steps,
       scheduler=self.scheduler,
       scheduler_state=scheduler_state,
+      use_cfg_cache=use_cfg_cache,
+      height=height,
   )

   with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -141,26 +144,128 @@ def run_inference_2_1(
     num_inference_steps: int,
     scheduler: FlaxUniPCMultistepScheduler,
     scheduler_state,
+    use_cfg_cache: bool = False,
+    height: int = 480,
 ):
-  do_classifier_free_guidance = guidance_scale > 1.0
-  if do_classifier_free_guidance:
-    prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
+  """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
+
+  CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
+    - Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
+                       Cache raw noise_cond and noise_uncond for the FFT bias.
+    - Cache steps    : run transformer on the cond batch only (batch×1).
+                       Estimate uncond via FFT frequency-domain compensation:
+                         ΔF = FFT(cached_uncond) - FFT(cached_cond)
+                         Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
+                         uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
+                       Phase-dependent weights (α=0.2):
+                         Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
+                         Late (low noise):   w1=1.0, w2=1.2 (boost high-freq)
+    - Schedule       : full CFG for the first 1/3 of steps, then
+                       full CFG every 5 steps, cache the rest.
+
+  Two separately-compiled JAX-jitted functions handle full and cache steps so
+  XLA sees static shapes throughout — the key requirement for TPU efficiency.
+  """
+  do_cfg = guidance_scale > 1.0
+  bsz = latents.shape[0]
+
+  # Resolution-dependent CFG cache config (FasterCache / MixCache guidance).
+  if height >= 720:
+    # 720p: conservative — protect the last 10% of steps, interval=5
+    cfg_cache_interval = 5
+    cfg_cache_start_step = int(num_inference_steps / 3)
+    cfg_cache_end_step = int(num_inference_steps * 0.9)
+    cfg_cache_alpha = 0.2
+  else:
+    # 480p: moderate — protect the last 2 steps, interval=5
+    cfg_cache_interval = 5
+    cfg_cache_start_step = int(num_inference_steps / 3)
+    cfg_cache_end_step = num_inference_steps - 2
+    cfg_cache_alpha = 0.2
+
+  # Pre-split embeds once, outside the loop.
+  prompt_cond_embeds = prompt_embeds
+  prompt_embeds_combined = None
+  if do_cfg:
+    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
+
+  # Pre-compute the cache schedule and phase-dependent weights.
+  # t₀ = midpoint step; before t₀ boost low-freq, after it boost high-freq.
+  t0_step = num_inference_steps // 2
+  first_full_step_seen = False
+  step_is_cache = []
+  step_w1w2 = []
+  for s in range(num_inference_steps):
+    is_cache = (
+        use_cfg_cache
+        and do_cfg
+        and first_full_step_seen
+        and s >= cfg_cache_start_step
+        and s < cfg_cache_end_step
+        and (s - cfg_cache_start_step) % cfg_cache_interval != 0
+    )
+    step_is_cache.append(is_cache)
+    if not is_cache:
+      first_full_step_seen = True
+    # Phase-dependent weights: w = 1 + α·I(condition)
+    if s < t0_step:
+      step_w1w2.append((1.0 + cfg_cache_alpha, 1.0))  # early: boost low-freq
+    else:
+      step_w1w2.append((1.0, 1.0 + cfg_cache_alpha))  # late: boost high-freq
+
+  # Cache tensors (None until the first full CFG step populates them with
+  # on-device JAX arrays).
+  cached_noise_cond = None
+  cached_noise_uncond = None
+
   for step in range(num_inference_steps):
     t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
-    if do_classifier_free_guidance:
-      latents = jnp.concatenate([latents] * 2)
-    timestep = jnp.broadcast_to(t, latents.shape[0])
-
-    noise_pred, latents = transformer_forward_pass(
-        graphdef,
-        sharded_state,
-        rest_of_state,
-        latents,
-        timestep,
-        prompt_embeds,
-        do_classifier_free_guidance=do_classifier_free_guidance,
-        guidance_scale=guidance_scale,
-    )
+    is_cache_step = step_is_cache[step]
+
+    if is_cache_step:
+      # ── Cache step: cond-only forward + FFT frequency compensation ──
+      w1, w2 = step_w1w2[step]
+      timestep = jnp.broadcast_to(t, bsz)
+      noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
+          graphdef,
+          sharded_state,
+          rest_of_state,
+          latents,
+          timestep,
+          prompt_cond_embeds,
+          cached_noise_cond,
+          cached_noise_uncond,
+          guidance_scale=guidance_scale,
+          w1=jnp.float32(w1),
+          w2=jnp.float32(w2),
+      )
+
+    elif do_cfg:
+      # ── Full CFG step: doubled batch, store raw cond/uncond for the cache ──
+      latents_doubled = jnp.concatenate([latents] * 2)
+      timestep = jnp.broadcast_to(t, bsz * 2)
+      noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
+          graphdef,
+          sharded_state,
+          rest_of_state,
+          latents_doubled,
+          timestep,
+          prompt_embeds_combined,
+          guidance_scale=guidance_scale,
+      )
+
+    else:
+      # ── No CFG (guidance_scale <= 1.0) ──
+      timestep = jnp.broadcast_to(t, bsz)
+      noise_pred, latents = transformer_forward_pass(
+          graphdef,
+          sharded_state,
+          rest_of_state,
+          latents,
+          timestep,
+          prompt_cond_embeds,
+          do_classifier_free_guidance=False,
+          guidance_scale=guidance_scale,
+      )

     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
   return latents
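
The cache schedule is plain Python, so the savings are easy to check offline. A minimal sketch reproducing the 480p schedule at this config's default num_inference_steps=30 (the first_full_step_seen guard is dropped, which is safe here because step 0 always precedes cfg_cache_start_step):

# Standalone reproduction of the 480p cache schedule (defaults from this commit).
num_inference_steps = 30
cfg_cache_interval = 5
cfg_cache_start_step = num_inference_steps // 3  # 10
cfg_cache_end_step = num_inference_steps - 2     # 28

cache_steps = [
    s
    for s in range(num_inference_steps)
    if cfg_cache_start_step <= s < cfg_cache_end_step
    and (s - cfg_cache_start_step) % cfg_cache_interval != 0
]
print(cache_steps)  # [11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23, 24, 26, 27]

So 14 of 30 steps (~45%) run cond-only at batch×1 while the other 16 run full CFG at batch×2: 46 single-batch forwards instead of 60, roughly a 23% cut in transformer work. On the 720p path only the end step changes (int(num_inference_steps * 0.9) = 27), which removes step 27 from the cache set.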
