Skip to content

Commit d4d4501

Browse files
committed
feat: fori_loop-based denoising for WAN inference
Replace the Python denoising loop with jax.lax.fori_loop for the non-cache paths of the WAN 2.1 and WAN 2.2 pipelines (both T2V and I2V). This compiles the entire denoising loop as a single XLA program, eliminating per-step Python dispatch overhead (~100 dispatches for 50 steps). WAN 2.1 (T2V & I2V) uses a single transformer inside the fori_loop; WAN 2.2 (T2V & I2V) selects between its two transformers via jax.lax.cond inside the fori_loop (both transformers share a graphdef). The fori_loop path is used only when no caching (CFG-Cache, MagCache, SenCache) is enabled; the existing Python loop paths are preserved as a fallback for cache-enabled configurations. Scheduler state is pre-initialized with concrete values (step_index=0, last_sample=zeros, begin_index=0) to ensure a consistent pytree structure across all fori_loop iterations.
1 parent ceca471 commit d4d4501

4 files changed

Lines changed: 467 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, init_magcache, magcache_step
16+
from ...schedulers.scheduling_unipc_multistep_flax import UniPCMultistepSchedulerState
1617
from ...models.wan.transformers.transformer_wan import WanModel
1718
from typing import List, Union, Optional
1819
from ...pyconfig import HyperParameters
@@ -127,6 +128,31 @@ def __call__(
127128

128129
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
129130

131+
# Use fori_loop path when no caching is enabled for reduced dispatch overhead.
132+
use_fori = not use_cfg_cache and not use_magcache
133+
134+
if use_fori:
135+
do_cfg = guidance_scale > 1.0
136+
p_run_inference = partial(
137+
run_inference_fori_2_1,
138+
do_cfg=do_cfg,
139+
guidance_scale=guidance_scale,
140+
num_inference_steps=num_inference_steps,
141+
scheduler=self.scheduler,
142+
)
143+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
144+
latents = p_run_inference(
145+
graphdef=graphdef,
146+
sharded_state=state,
147+
rest_of_state=rest_of_state,
148+
latents=latents,
149+
prompt_embeds=prompt_embeds,
150+
negative_prompt_embeds=negative_prompt_embeds,
151+
scheduler_state=scheduler_state,
152+
)
153+
latents = self._denormalize_latents(latents)
154+
return self._decode_latents_to_video(latents)
155+
130156
p_run_inference = partial(
131157
run_inference_2_1,
132158
guidance_scale=guidance_scale,
@@ -155,6 +181,77 @@ def __call__(
155181
return self._decode_latents_to_video(latents)
156182

157183

184+
@partial(jax.jit, static_argnames=("do_cfg", "guidance_scale", "num_inference_steps", "scheduler"))
def run_inference_fori_2_1(
    graphdef,
    sharded_state,
    rest_of_state,
    latents: jnp.array,
    prompt_embeds: jnp.array,
    negative_prompt_embeds: jnp.array,
    scheduler_state: UniPCMultistepSchedulerState,
    do_cfg: bool,
    guidance_scale: float,
    num_inference_steps: int,
    scheduler: FlaxUniPCMultistepScheduler,
):
  """Run the WAN 2.1 T2V denoising loop as a single jax.lax.fori_loop.

  Compiling the whole loop into one XLA program removes the per-step
  Python dispatch cost of the plain-Python loop. This path is taken only
  when no caching (CFG-Cache, MagCache) is enabled.
  """
  batch = latents.shape[0]

  # With CFG on, stack cond/uncond text embeddings once at trace time so a
  # single transformer call serves both halves of the doubled batch.
  text_embeds = (
      jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_cfg else prompt_embeds
  )

  # Give every optional scheduler field a concrete value up front so the
  # loop-carry pytree has an identical structure on every iteration:
  #   - step_index=0 keeps scheduler.step from running _init_step_index
  #   - last_sample must be an array (not None); the step-0 corrector is
  #     still skipped because step_index > 0 is False there
  #   - begin_index=0 replaces None with a concrete int
  scheduler_state = scheduler_state.replace(
      step_index=0,
      last_sample=jnp.zeros_like(latents),
      begin_index=0,
  )

  def denoise_step(step, carry):
    x, sched = carry
    t = sched.timesteps[step]

    # Re-assembling the module from graphdef + state happens at trace time
    # only; no per-step runtime cost inside the compiled loop.
    model = nnx.merge(graphdef, sharded_state, rest_of_state)

    if do_cfg:
      model_in = jnp.concatenate([x] * 2)
      timestep = jnp.broadcast_to(t, (batch * 2,))
    else:
      model_in = x
      timestep = jnp.broadcast_to(t, (batch,))

    eps = model(
        hidden_states=model_in,
        timestep=timestep,
        encoder_hidden_states=text_embeds,
    )

    if do_cfg:
      # First half of the batch is conditional, second half unconditional.
      eps_cond, eps_uncond = eps[:batch], eps[batch:]
      eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    x, sched = scheduler.step(sched, eps, t, x).to_tuple()
    return x, sched

  latents, _ = jax.lax.fori_loop(0, num_inference_steps, denoise_step, (latents, scheduler_state))
  return latents
253+
254+
158255
def run_inference_2_1(
159256
graphdef,
160257
sharded_state,

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
16+
from ...schedulers.scheduling_unipc_multistep_flax import UniPCMultistepSchedulerState
1617
from ...models.wan.transformers.transformer_wan import WanModel
1718
from typing import List, Union, Optional
1819
from ...pyconfig import HyperParameters
@@ -150,6 +151,35 @@ def __call__(
150151

151152
boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps
152153

154+
# Use fori_loop path when no caching is enabled for reduced dispatch overhead.
155+
use_fori = not use_cfg_cache and not use_sen_cache
156+
157+
if use_fori:
158+
do_cfg = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
159+
p_run_inference = partial(
160+
run_inference_fori_2_2,
161+
do_cfg=do_cfg,
162+
guidance_scale_low=guidance_scale_low,
163+
guidance_scale_high=guidance_scale_high,
164+
boundary=boundary_timestep,
165+
num_inference_steps=num_inference_steps,
166+
scheduler=self.scheduler,
167+
)
168+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
169+
latents = p_run_inference(
170+
graphdef=high_noise_graphdef,
171+
low_noise_state=low_noise_state,
172+
low_noise_rest=low_noise_rest,
173+
high_noise_state=high_noise_state,
174+
high_noise_rest=high_noise_rest,
175+
latents=latents,
176+
prompt_embeds=prompt_embeds,
177+
negative_prompt_embeds=negative_prompt_embeds,
178+
scheduler_state=scheduler_state,
179+
)
180+
latents = self._denormalize_latents(latents)
181+
return self._decode_latents_to_video(latents)
182+
153183
p_run_inference = partial(
154184
run_inference_2_2,
155185
guidance_scale_low=guidance_scale_low,
@@ -179,6 +209,97 @@ def __call__(
179209
return self._decode_latents_to_video(latents)
180210

181211

212+
@partial(jax.jit, static_argnames=("do_cfg", "guidance_scale_low", "guidance_scale_high", "boundary", "num_inference_steps", "scheduler"))
def run_inference_fori_2_2(
    graphdef,
    low_noise_state,
    low_noise_rest,
    high_noise_state,
    high_noise_rest,
    latents: jnp.array,
    prompt_embeds: jnp.array,
    negative_prompt_embeds: jnp.array,
    scheduler_state: UniPCMultistepSchedulerState,
    do_cfg: bool,
    guidance_scale_low: float,
    guidance_scale_high: float,
    boundary: float,
    num_inference_steps: int,
    scheduler: FlaxUniPCMultistepScheduler,
):
  """Denoising loop for WAN 2.2 T2V using jax.lax.fori_loop.

  The entire denoising loop runs as a single XLA program, eliminating
  per-step Python dispatch overhead. Dual-transformer selection
  (high-noise vs low-noise based on boundary timestep) is handled
  inside the loop using jax.lax.cond.

  Both transformers share the same architecture (identical graphdef),
  so a single graphdef is used with jax.lax.cond selecting between
  the two weight states per step.

  Args:
    graphdef: nnx graphdef shared by both transformers.
    low_noise_state / low_noise_rest: weight and remaining nnx state for
      the low-noise transformer (timesteps below `boundary`).
    high_noise_state / high_noise_rest: weight and remaining nnx state for
      the high-noise transformer (timesteps at or above `boundary`).
    latents: initial noisy latents; first axis is batch.
    prompt_embeds / negative_prompt_embeds: text embeddings for the
      conditional and unconditional passes.
    scheduler_state: UniPC scheduler state; optional fields are made
      concrete below before entering the loop.
    do_cfg: static flag — whether classifier-free guidance is applied.
    guidance_scale_low / guidance_scale_high: per-phase CFG scales.
    boundary: timestep threshold separating the two transformer phases.
    num_inference_steps: number of denoising iterations.
    scheduler: the (static, hashable) scheduler object.

  Returns:
    The denoised latents after `num_inference_steps` steps.
  """
  bsz = latents.shape[0]

  # Pre-combine embeddings for CFG (static at trace time); the conditional
  # half comes first, which fixes the slicing order used on noise_pred below.
  if do_cfg:
    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
  else:
    prompt_embeds_combined = prompt_embeds

  # Pre-initialize scheduler state with concrete values for fori_loop:
  # step_index/last_sample/begin_index must not be None so the loop-carry
  # pytree keeps an identical structure on every iteration.
  scheduler_state = scheduler_state.replace(
      step_index=0,
      last_sample=jnp.zeros_like(latents),
      begin_index=0,
  )

  def body_fn(step, carry):
    # carry = (current latents, scheduler state); `step` indexes timesteps.
    latents, sched_state = carry
    t = sched_state.timesteps[step]
    # High-noise transformer handles timesteps at or above the boundary.
    use_high = t >= boundary

    # Select guidance scale based on transformer phase.
    guidance_scale = jnp.where(use_high, guidance_scale_high, guidance_scale_low)

    if do_cfg:
      latents_input = jnp.concatenate([latents] * 2)
      timestep = jnp.broadcast_to(t, (bsz * 2,))
    else:
      latents_input = latents
      timestep = jnp.broadcast_to(t, (bsz,))

    # Select transformer weights via jax.lax.cond.
    # Both branches trace through the same graphdef with different states;
    # latents_input/timestep/embeddings are captured by closure.
    def high_noise_forward():
      transformer = nnx.merge(graphdef, high_noise_state, high_noise_rest)
      return transformer(
          hidden_states=latents_input,
          timestep=timestep,
          encoder_hidden_states=prompt_embeds_combined,
      )

    def low_noise_forward():
      transformer = nnx.merge(graphdef, low_noise_state, low_noise_rest)
      return transformer(
          hidden_states=latents_input,
          timestep=timestep,
          encoder_hidden_states=prompt_embeds_combined,
      )

    noise_pred = jax.lax.cond(use_high, high_noise_forward, low_noise_forward)

    if do_cfg:
      # First bsz rows are conditional, remaining bsz rows unconditional.
      noise_cond = noise_pred[:bsz]
      noise_uncond = noise_pred[bsz:]
      noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

    latents, sched_state = scheduler.step(sched_state, noise_pred, t, latents).to_tuple()
    return latents, sched_state

  latents, _ = jax.lax.fori_loop(0, num_inference_steps, body_fn, (latents, scheduler_state))
  return latents
301+
302+
182303
def run_inference_2_2(
183304
low_noise_graphdef,
184305
low_noise_state,

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from maxdiffusion import max_logging
1616
from maxdiffusion.image_processor import PipelineImageInput
1717
from .wan_pipeline import WanPipeline, transformer_forward_pass, init_magcache, magcache_step
18+
from ...schedulers.scheduling_unipc_multistep_flax import UniPCMultistepSchedulerState
1819
from ...models.wan.transformers.transformer_wan import WanModel
1920
from typing import List, Union, Optional, Tuple
2021
from ...pyconfig import HyperParameters
@@ -236,6 +237,37 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
236237
if first_frame_mask is not None:
237238
first_frame_mask = jax.device_put(first_frame_mask, data_sharding)
238239

240+
# Use fori_loop path when no caching is enabled for reduced dispatch overhead.
241+
use_fori = not use_magcache
242+
243+
if use_fori:
244+
do_cfg = guidance_scale > 1.0
245+
p_run_inference = partial(
246+
run_inference_fori_2_1_i2v,
247+
do_cfg=do_cfg,
248+
guidance_scale=guidance_scale,
249+
num_inference_steps=num_inference_steps,
250+
scheduler=self.scheduler,
251+
)
252+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
253+
latents = p_run_inference(
254+
graphdef=graphdef,
255+
sharded_state=state,
256+
rest_of_state=rest_of_state,
257+
latents=latents,
258+
condition=condition,
259+
prompt_embeds=prompt_embeds,
260+
negative_prompt_embeds=negative_prompt_embeds,
261+
image_embeds=image_embeds,
262+
scheduler_state=scheduler_state,
263+
)
264+
latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
265+
latents = self._denormalize_latents(latents)
266+
267+
if output_type == "latent":
268+
return latents
269+
return self._decode_latents_to_video(latents)
270+
239271
p_run_inference = partial(
240272
run_inference_2_1_i2v,
241273
graphdef=graphdef,
@@ -269,6 +301,85 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
269301
return self._decode_latents_to_video(latents)
270302

271303

304+
@partial(jax.jit, static_argnames=("do_cfg", "guidance_scale", "num_inference_steps", "scheduler"))
def run_inference_fori_2_1_i2v(
    graphdef,
    sharded_state,
    rest_of_state,
    latents: jnp.array,
    condition: jnp.array,
    prompt_embeds: jnp.array,
    negative_prompt_embeds: jnp.array,
    image_embeds: jnp.array,
    scheduler_state: UniPCMultistepSchedulerState,
    do_cfg: bool,
    guidance_scale: float,
    num_inference_steps: int,
    scheduler: FlaxUniPCMultistepScheduler,
):
  """Run the WAN 2.1 I2V denoising loop as a single jax.lax.fori_loop.

  The whole loop compiles into one XLA program, avoiding per-step Python
  dispatch. I2V-specific details: the conditioning latents are concatenated
  onto the noisy latents along the channel axis and the image embeddings are
  passed to the transformer each step. Used only when MagCache is disabled.
  """
  batch = latents.shape[0]

  # When CFG is on, duplicate every conditioning input once at trace time so
  # one transformer call covers both halves of the doubled batch.
  if do_cfg:
    text_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
    img_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
    cond = jnp.concatenate([condition] * 2)
  else:
    text_embeds = prompt_embeds
    img_embeds = image_embeds
    cond = condition

  # Concretize the optional scheduler fields (step_index, last_sample,
  # begin_index) so the loop-carry pytree keeps one structure across all
  # fori_loop iterations.
  scheduler_state = scheduler_state.replace(
      step_index=0,
      last_sample=jnp.zeros_like(latents),
      begin_index=0,
  )

  def denoise_step(step, carry):
    x, sched = carry
    t = sched.timesteps[step]

    # Trace-time reassembly of the transformer from graphdef + states.
    model = nnx.merge(graphdef, sharded_state, rest_of_state)

    if do_cfg:
      x_in = jnp.concatenate([x] * 2)
      timestep = jnp.broadcast_to(t, (batch * 2,))
    else:
      x_in = x
      timestep = jnp.broadcast_to(t, (batch,))

    # Stack the conditioning channels on, then go BFHWC -> BCFHW for the model.
    model_in = jnp.transpose(jnp.concatenate([x_in, cond], axis=-1), (0, 4, 1, 2, 3))

    eps = model(
        hidden_states=model_in,
        timestep=timestep,
        encoder_hidden_states=text_embeds,
        encoder_hidden_states_image=img_embeds,
    )

    if do_cfg:
      # First half of the batch is conditional, second half unconditional.
      eps_cond, eps_uncond = eps[:batch], eps[batch:]
      eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    # Back to BFHWC latent layout before the scheduler update.
    eps = jnp.transpose(eps, (0, 2, 3, 4, 1))
    x, sched = scheduler.step(sched, eps, t, x).to_tuple()
    return x, sched

  latents, _ = jax.lax.fori_loop(0, num_inference_steps, denoise_step, (latents, scheduler_state))
  return latents
381+
382+
272383
def run_inference_2_1_i2v(
273384
graphdef,
274385
sharded_state,

0 commit comments

Comments
 (0)