1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from .wan_pipeline import WanPipeline , transformer_forward_pass
15+ from .wan_pipeline import WanPipeline , transformer_forward_pass , transformer_forward_pass_full_cfg , transformer_forward_pass_cfg_cache
1616from ...models .wan .transformers .transformer_wan import WanModel
1717from typing import List , Union , Optional
1818from ...pyconfig import HyperParameters
@@ -91,6 +91,7 @@ def __call__(
9191 prompt_embeds : Optional [jax .Array ] = None ,
9292 negative_prompt_embeds : Optional [jax .Array ] = None ,
9393 vae_only : bool = False ,
94+ use_cfg_cache : bool = False ,
9495 ):
9596 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
9697 prompt ,
@@ -115,6 +116,8 @@ def __call__(
115116 num_inference_steps = num_inference_steps ,
116117 scheduler = self .scheduler ,
117118 scheduler_state = scheduler_state ,
119+ use_cfg_cache = use_cfg_cache ,
120+ height = height ,
118121 )
119122
120123 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -141,26 +144,128 @@ def run_inference_2_1(
141144 num_inference_steps : int ,
142145 scheduler : FlaxUniPCMultistepScheduler ,
143146 scheduler_state ,
147+ use_cfg_cache : bool = False ,
148+ height : int = 480 ,
144149):
145- do_classifier_free_guidance = guidance_scale > 1.0
146- if do_classifier_free_guidance :
147- prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
150+ """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
151+
152+ CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
153+ - Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
154+ Cache raw noise_cond and noise_uncond for FFT bias.
155+ - Cache steps : run transformer on cond batch only (batch×1).
156+ Estimate uncond via FFT frequency-domain compensation:
157+ ΔF = FFT(cached_uncond) - FFT(cached_cond)
158+ Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
159+ uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
160+ Phase-dependent weights (α=0.2):
161+ Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
162+ Late (low noise): w1=1.0, w2=1.2 (boost high-freq)
163+ - Schedule : full CFG for the first 1/3 of steps and for the
164+ tail past the cache window; in between, full CFG every 5 steps, cache the rest.
165+
166+ Two separately-compiled JAX-jitted functions handle full and cache steps so
167+ XLA sees static shapes throughout — the key requirement for TPU efficiency.
168+ """
169+ do_cfg = guidance_scale > 1.0
170+ bsz = latents .shape [0 ]
171+
172+ # Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
173+ if height >= 720 :
174+ # 720p: conservative — protect last 10%, interval=5
175+ cfg_cache_interval = 5
176+ cfg_cache_start_step = int (num_inference_steps / 3 )
177+ cfg_cache_end_step = int (num_inference_steps * 0.9 )
178+ cfg_cache_alpha = 0.2
179+ else :
180+ # 480p: moderate — protect last 2 steps, interval=5
181+ cfg_cache_interval = 5
182+ cfg_cache_start_step = int (num_inference_steps / 3 )
183+ cfg_cache_end_step = num_inference_steps - 2
184+ cfg_cache_alpha = 0.2
185+
186+ # Pre-split embeds once, outside the loop.
187+ prompt_cond_embeds = prompt_embeds
188+ prompt_embeds_combined = None
189+ if do_cfg :
190+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
191+
192+ # Pre-compute cache schedule and phase-dependent weights.
193+ # t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
194+ t0_step = num_inference_steps // 2
195+ first_full_step_seen = False
196+ step_is_cache = []
197+ step_w1w2 = []
198+ for s in range (num_inference_steps ):
199+ is_cache = (
200+ use_cfg_cache
201+ and do_cfg
202+ and first_full_step_seen
203+ and s >= cfg_cache_start_step
204+ and s < cfg_cache_end_step
205+ and (s - cfg_cache_start_step ) % cfg_cache_interval != 0
206+ )
207+ step_is_cache .append (is_cache )
208+ if not is_cache :
209+ first_full_step_seen = True
210+ # Phase-dependent weights: w = 1 + α·I(condition)
211+ if s < t0_step :
212+ step_w1w2 .append ((1.0 + cfg_cache_alpha , 1.0 )) # early: boost low-freq
213+ else :
214+ step_w1w2 .append ((1.0 , 1.0 + cfg_cache_alpha )) # late: boost high-freq
215+
216+ # Cache slots: start as None; populated with on-device JAX arrays by the first full CFG step.
217+ cached_noise_cond = None
218+ cached_noise_uncond = None
219+
148220 for step in range (num_inference_steps ):
149221 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
150- if do_classifier_free_guidance :
151- latents = jnp .concatenate ([latents ] * 2 )
152- timestep = jnp .broadcast_to (t , latents .shape [0 ])
153-
154- noise_pred , latents = transformer_forward_pass (
155- graphdef ,
156- sharded_state ,
157- rest_of_state ,
158- latents ,
159- timestep ,
160- prompt_embeds ,
161- do_classifier_free_guidance = do_classifier_free_guidance ,
162- guidance_scale = guidance_scale ,
163- )
222+ is_cache_step = step_is_cache [step ]
223+
224+ if is_cache_step :
225+ # ── Cache step: cond-only forward + FFT frequency compensation ──
226+ w1 , w2 = step_w1w2 [step ]
227+ timestep = jnp .broadcast_to (t , bsz )
228+ noise_pred , cached_noise_cond = transformer_forward_pass_cfg_cache (
229+ graphdef ,
230+ sharded_state ,
231+ rest_of_state ,
232+ latents ,
233+ timestep ,
234+ prompt_cond_embeds ,
235+ cached_noise_cond ,
236+ cached_noise_uncond ,
237+ guidance_scale = guidance_scale ,
238+ w1 = jnp .float32 (w1 ),
239+ w2 = jnp .float32 (w2 ),
240+ )
241+
242+ elif do_cfg :
243+ # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
244+ latents_doubled = jnp .concatenate ([latents ] * 2 )
245+ timestep = jnp .broadcast_to (t , bsz * 2 )
246+ noise_pred , cached_noise_cond , cached_noise_uncond = transformer_forward_pass_full_cfg (
247+ graphdef ,
248+ sharded_state ,
249+ rest_of_state ,
250+ latents_doubled ,
251+ timestep ,
252+ prompt_embeds_combined ,
253+ guidance_scale = guidance_scale ,
254+ )
255+
256+ else :
257+ # ── No CFG (guidance_scale <= 1.0) ──
258+ timestep = jnp .broadcast_to (t , bsz )
259+ noise_pred , latents = transformer_forward_pass (
260+ graphdef ,
261+ sharded_state ,
262+ rest_of_state ,
263+ latents ,
264+ timestep ,
265+ prompt_cond_embeds ,
266+ do_classifier_free_guidance = False ,
267+ guidance_scale = guidance_scale ,
268+ )
164269
165270 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
166271 return latents