Skip to content

Commit 2a82a6e

Browse files
committed
fix and refactor
1 parent 00a19e2 commit 2a82a6e

2 files changed

Lines changed: 30 additions & 77 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 21 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,13 @@ def __call__(
9696
magcache_K: Optional[int] = None,
9797
retention_ratio: Optional[float] = None,
9898
):
99+
config = getattr(self, "config", None)
99100
if magcache_thresh is None:
100-
magcache_thresh = getattr(self.config, "magcache_thresh", 0.12)
101+
magcache_thresh = getattr(config, "magcache_thresh", 0.12)
101102
if magcache_K is None:
102-
magcache_K = getattr(self.config, "magcache_K", 2)
103+
magcache_K = getattr(config, "magcache_K", 2)
103104
if retention_ratio is None:
104-
retention_ratio = getattr(self.config, "retention_ratio", 0.2)
105+
retention_ratio = getattr(config, "retention_ratio", 0.2)
105106

106107
if use_cfg_cache and guidance_scale <= 1.0:
107108
raise ValueError(
@@ -138,7 +139,7 @@ def __call__(
138139
magcache_K=magcache_K,
139140
retention_ratio=retention_ratio,
140141
height=height,
141-
mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None,
142+
mag_ratios_base=getattr(config, "mag_ratios_base", None),
142143
)
143144

144145
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -244,43 +245,23 @@ def run_inference_2_1(
244245
cached_noise_uncond = None
245246

246247
if use_magcache and do_cfg:
247-
(
248-
accumulated_ratio_cond,
249-
accumulated_ratio_uncond,
250-
accumulated_err_cond,
251-
accumulated_err_uncond,
252-
accumulated_steps_cond,
253-
accumulated_steps_uncond,
254-
cached_residual,
255-
skip_warmup,
256-
mag_ratios,
257-
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
258-
259-
for step in range(num_inference_steps):
260-
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
248+
magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
249+
accumulated_state = magcache_init[:6]
250+
cached_residual = magcache_init[6]
251+
skip_warmup = magcache_init[7]
252+
mag_ratios = magcache_init[8]
253+
254+
for step in range(num_inference_steps):
255+
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
256+
257+
if use_magcache and do_cfg:
261258
timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz)
262259

263-
accumulated_state = (
264-
accumulated_ratio_cond,
265-
accumulated_ratio_uncond,
266-
accumulated_err_cond,
267-
accumulated_err_uncond,
268-
accumulated_steps_cond,
269-
accumulated_steps_uncond,
270-
)
271260
skip_blocks, accumulated_state = magcache_step(
272261
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
273262
)
274-
(
275-
accumulated_ratio_cond,
276-
accumulated_ratio_uncond,
277-
accumulated_err_cond,
278-
accumulated_err_uncond,
279-
accumulated_steps_cond,
280-
accumulated_steps_uncond,
281-
) = accumulated_state
282-
283-
outputs = transformer_forward_pass(
263+
264+
noise_pred, latents, residual_x_cur = transformer_forward_pass(
284265
graphdef,
285266
sharded_state,
286267
rest_of_state,
@@ -294,18 +275,10 @@ def run_inference_2_1(
294275
return_residual=True,
295276
)
296277

297-
noise_pred, latents_returned, residual_x_cur = outputs
298-
299278
if not skip_blocks:
300279
cached_residual = residual_x_cur
301280

302-
latents = latents_returned
303-
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
304-
return latents
305-
306-
else:
307-
for step in range(num_inference_steps):
308-
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
281+
else:
309282
is_cache_step = step_is_cache[step]
310283

311284
if is_cache_step:
@@ -351,5 +324,6 @@ def run_inference_2_1(
351324
guidance_scale=guidance_scale,
352325
)
353326

354-
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
355-
return latents
327+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
328+
329+
return latents

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,13 @@ def __call__(
154154
magcache_K: Optional[int] = None,
155155
retention_ratio: Optional[float] = None,
156156
):
157+
config = getattr(self, "config", None)
157158
if magcache_thresh is None:
158-
magcache_thresh = getattr(self.config, "magcache_thresh", 0.04)
159+
magcache_thresh = getattr(config, "magcache_thresh", 0.04)
159160
if magcache_K is None:
160-
magcache_K = getattr(self.config, "magcache_K", 2)
161+
magcache_K = getattr(config, "magcache_K", 2)
161162
if retention_ratio is None:
162-
retention_ratio = getattr(self.config, "retention_ratio", 0.2)
163+
retention_ratio = getattr(config, "retention_ratio", 0.2)
163164

164165
height = height or self.config.height
165166
width = width or self.config.width
@@ -291,17 +292,11 @@ def run_inference_2_1_i2v(
291292
do_cfg = guidance_scale > 1.0
292293

293294
if use_magcache and do_cfg:
294-
(
295-
accumulated_ratio_cond,
296-
accumulated_ratio_uncond,
297-
accumulated_err_cond,
298-
accumulated_err_uncond,
299-
accumulated_steps_cond,
300-
accumulated_steps_uncond,
301-
cached_residual,
302-
skip_warmup,
303-
mag_ratios,
304-
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
295+
magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
296+
accumulated_state = magcache_init[:6]
297+
cached_residual = magcache_init[6]
298+
skip_warmup = magcache_init[7]
299+
mag_ratios = magcache_init[8]
305300

306301
if do_cfg:
307302
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
@@ -317,25 +312,9 @@ def run_inference_2_1_i2v(
317312

318313
skip_blocks = False
319314
if use_magcache and do_cfg:
320-
accumulated_state = (
321-
accumulated_ratio_cond,
322-
accumulated_ratio_uncond,
323-
accumulated_err_cond,
324-
accumulated_err_uncond,
325-
accumulated_steps_cond,
326-
accumulated_steps_uncond,
327-
)
328315
skip_blocks, accumulated_state = magcache_step(
329316
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
330317
)
331-
(
332-
accumulated_ratio_cond,
333-
accumulated_ratio_uncond,
334-
accumulated_err_cond,
335-
accumulated_err_uncond,
336-
accumulated_steps_cond,
337-
accumulated_steps_uncond,
338-
) = accumulated_state
339318

340319
latents_input = latents
341320
if do_cfg:

0 commit comments

Comments (0)