1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from .wan_pipeline import WanPipeline , transformer_forward_pass
15+ from .wan_pipeline import WanPipeline , transformer_forward_pass , transformer_forward_pass_full_cfg , transformer_forward_pass_cfg_cache
1616from ...models .wan .transformers .transformer_wan import WanModel
1717from typing import List , Union , Optional
1818from ...pyconfig import HyperParameters
@@ -90,6 +90,7 @@ def __call__(
9090 prompt_embeds : Optional [jax .Array ] = None ,
9191 negative_prompt_embeds : Optional [jax .Array ] = None ,
9292 vae_only : bool = False ,
93+ use_cfg_cache : bool = False ,
9394 ):
9495 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
9596 prompt ,
@@ -114,6 +115,8 @@ def __call__(
114115 num_inference_steps = num_inference_steps ,
115116 scheduler = self .scheduler ,
116117 scheduler_state = scheduler_state ,
118+ use_cfg_cache = use_cfg_cache ,
119+ height = height ,
117120 )
118121
119122 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -140,26 +143,128 @@ def run_inference_2_1(
140143 num_inference_steps : int ,
141144 scheduler : FlaxUniPCMultistepScheduler ,
142145 scheduler_state ,
146+ use_cfg_cache : bool = False ,
147+ height : int = 480 ,
143148):
144- do_classifier_free_guidance = guidance_scale > 1.0
145- if do_classifier_free_guidance :
146- prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
149+ """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
150+
151+ CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
152+ - Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
153+ Cache raw noise_cond and noise_uncond for FFT bias.
154+ - Cache steps : run transformer on cond batch only (batch×1).
155+ Estimate uncond via FFT frequency-domain compensation:
156+ ΔF = FFT(cached_uncond) - FFT(cached_cond)
157+ Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
158+ uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
159+ Phase-dependent weights (α=0.2):
160+ Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
161+ Late (low noise): w1=1.0, w2=1.2 (boost high-freq)
150+     - Schedule     : full CFG for the first 1/3 of steps, then every 5th step;
151+       cache the rest, except steps >= cfg_cache_end_step (always full CFG).
164+
165+ Two separately-compiled JAX-jitted functions handle full and cache steps so
166+ XLA sees static shapes throughout — the key requirement for TPU efficiency.
167+ """
168+ do_cfg = guidance_scale > 1.0
169+ bsz = latents .shape [0 ]
170+
171+ # Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
172+ if height >= 720 :
173+ # 720p: conservative — protect last 40%, interval=5
174+ cfg_cache_interval = 5
175+ cfg_cache_start_step = int (num_inference_steps / 3 )
176+ cfg_cache_end_step = int (num_inference_steps * 0.9 )
177+ cfg_cache_alpha = 0.2
178+ else :
179+ # 480p: moderate — protect last 2 steps, interval=5
180+ cfg_cache_interval = 5
181+ cfg_cache_start_step = int (num_inference_steps / 3 )
182+ cfg_cache_end_step = num_inference_steps - 2
183+ cfg_cache_alpha = 0.2
184+
185+ # Pre-split embeds once, outside the loop.
186+ prompt_cond_embeds = prompt_embeds
187+ prompt_embeds_combined = None
188+ if do_cfg :
189+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
190+
191+ # Pre-compute cache schedule and phase-dependent weights.
192+ # t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
193+ t0_step = num_inference_steps // 2
194+ first_full_step_seen = False
195+ step_is_cache = []
196+ step_w1w2 = []
197+ for s in range (num_inference_steps ):
198+ is_cache = (
199+ use_cfg_cache
200+ and do_cfg
201+ and first_full_step_seen
202+ and s >= cfg_cache_start_step
203+ and s < cfg_cache_end_step
204+ and (s - cfg_cache_start_step ) % cfg_cache_interval != 0
205+ )
206+ step_is_cache .append (is_cache )
207+ if not is_cache :
208+ first_full_step_seen = True
209+ # Phase-dependent weights: w = 1 + α·I(condition)
210+ if s < t0_step :
211+ step_w1w2 .append ((1.0 + cfg_cache_alpha , 1.0 )) # early: boost low-freq
212+ else :
213+ step_w1w2 .append ((1.0 , 1.0 + cfg_cache_alpha )) # late: boost high-freq
214+
215+ # Cache tensors (on-device JAX arrays, initialised to None).
216+ cached_noise_cond = None
217+ cached_noise_uncond = None
218+
147219 for step in range (num_inference_steps ):
148220 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
149- if do_classifier_free_guidance :
150- latents = jnp .concatenate ([latents ] * 2 )
151- timestep = jnp .broadcast_to (t , latents .shape [0 ])
152-
153- noise_pred , latents = transformer_forward_pass (
154- graphdef ,
155- sharded_state ,
156- rest_of_state ,
157- latents ,
158- timestep ,
159- prompt_embeds ,
160- do_classifier_free_guidance = do_classifier_free_guidance ,
161- guidance_scale = guidance_scale ,
162- )
221+ is_cache_step = step_is_cache [step ]
222+
223+ if is_cache_step :
224+ # ── Cache step: cond-only forward + FFT frequency compensation ──
225+ w1 , w2 = step_w1w2 [step ]
226+ timestep = jnp .broadcast_to (t , bsz )
227+ noise_pred , cached_noise_cond = transformer_forward_pass_cfg_cache (
228+ graphdef ,
229+ sharded_state ,
230+ rest_of_state ,
231+ latents ,
232+ timestep ,
233+ prompt_cond_embeds ,
234+ cached_noise_cond ,
235+ cached_noise_uncond ,
236+ guidance_scale = guidance_scale ,
237+ w1 = jnp .float32 (w1 ),
238+ w2 = jnp .float32 (w2 ),
239+ )
240+
241+ elif do_cfg :
242+ # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
243+ latents_doubled = jnp .concatenate ([latents ] * 2 )
244+ timestep = jnp .broadcast_to (t , bsz * 2 )
245+ noise_pred , cached_noise_cond , cached_noise_uncond = transformer_forward_pass_full_cfg (
246+ graphdef ,
247+ sharded_state ,
248+ rest_of_state ,
249+ latents_doubled ,
250+ timestep ,
251+ prompt_embeds_combined ,
252+ guidance_scale = guidance_scale ,
253+ )
254+
255+ else :
256+ # ── No CFG (guidance_scale <= 1.0) ──
257+ timestep = jnp .broadcast_to (t , bsz )
258+ noise_pred , latents = transformer_forward_pass (
259+ graphdef ,
260+ sharded_state ,
261+ rest_of_state ,
262+ latents ,
263+ timestep ,
264+ prompt_cond_embeds ,
265+ do_classifier_free_guidance = False ,
266+ guidance_scale = guidance_scale ,
267+ )
163268
164269 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
165270 return latents