Skip to content

Commit 7d4d8ec

Browse files
committed
refactor
1 parent d2a5e21 commit 7d4d8ec

9 files changed

Lines changed: 253 additions & 192 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ use_magcache: True
332332
magcache_thresh: 0.12
333333
magcache_K: 2
334334
retention_ratio: 0.2
335+
mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]
335336

336337
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
337338
guidance_rescale: 0.0

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ use_magcache: True
308308
magcache_thresh: 0.12
309309
magcache_K: 2
310310
retention_ratio: 0.2
311+
mag_ratios_base: [1.0, 1.0, 1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
311312

312313
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
313314
# when predicted output change (based on accumulated latent/timestep drift) is small

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,8 @@ use_magcache: True
292292
magcache_thresh: 0.12
293293
magcache_K: 2
294294
retention_ratio: 0.2
295+
mag_ratios_base_720p: [1.0, 1.0, 0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]
296+
mag_ratios_base_480p: [1.0, 1.0, 0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]
295297

296298
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
297299
guidance_rescale: 0.0

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ use_magcache: True
304304
magcache_thresh: 0.12
305305
magcache_K: 2
306306
retention_ratio: 0.2
307+
mag_ratios_base: [1.0, 1.0, 0.99512, 0.99559, 0.99559, 0.99561, 0.99595, 0.99577, 0.99512, 0.99512, 0.99546, 0.99534, 0.99543, 0.99531, 0.99496, 0.99491, 0.99504, 0.99499, 0.99444, 0.99449, 0.99481, 0.99481, 0.99435, 0.99435, 0.9943, 0.99431, 0.99411, 0.99406, 0.99373, 0.99376, 0.99413, 0.99405, 0.99363, 0.99359, 0.99335, 0.99331, 0.99244, 0.99243, 0.99229, 0.99229, 0.99239, 0.99236, 0.99163, 0.9916, 0.99149, 0.99151, 0.99191, 0.99192, 0.9898, 0.98981, 0.9899, 0.98987, 0.98849, 0.98849, 0.98846, 0.98846, 0.98861, 0.98861, 0.9874, 0.98738, 0.98588, 0.98589, 0.98539, 0.98534, 0.98444, 0.98439, 0.9831, 0.98309, 0.98119, 0.98118, 0.98001, 0.98, 0.97862, 0.97859, 0.97555, 0.97558, 0.97392, 0.97388, 0.97152, 0.97145, 0.96871, 0.9687, 0.96435, 0.96434, 0.96129, 0.96127, 0.95639, 0.95638, 0.95176, 0.95175, 0.94446, 0.94452, 0.93972, 0.93974, 0.93575, 0.9359, 0.93537, 0.93552, 0.96655, 0.96616]
307308

308309
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
309310
guidance_rescale: 0.0

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,4 +910,116 @@ def nearest_interp(src, target_len):
910910
return np.array([src[-1]])
911911
import numpy as np
912912
indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32)
913-
return src[indices]
913+
return src[indices]
914+
915+
def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
916+
"""Initialize MagCache variables and interpolate ratios.
917+
918+
Args:
919+
num_inference_steps: Number of inference steps.
920+
retention_ratio: Retention ratio of unchanged steps.
921+
mag_ratios_base: Base magnitude ratios array or list.
922+
"""
923+
import numpy as np
924+
925+
accumulated_ratio_cond = 1.0
926+
accumulated_ratio_uncond = 1.0
927+
accumulated_err_cond = 0.0
928+
accumulated_err_uncond = 0.0
929+
accumulated_steps_cond = 0
930+
accumulated_steps_uncond = 0
931+
cached_residual = None
932+
933+
skip_warmup = int(num_inference_steps * retention_ratio)
934+
935+
mag_ratios_base = np.array(mag_ratios_base)
936+
937+
if len(mag_ratios_base) != num_inference_steps * 2:
938+
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
939+
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
940+
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)
941+
else:
942+
mag_ratios = mag_ratios_base
943+
944+
return (
945+
accumulated_ratio_cond,
946+
accumulated_ratio_uncond,
947+
accumulated_err_cond,
948+
accumulated_err_uncond,
949+
accumulated_steps_cond,
950+
accumulated_steps_uncond,
951+
cached_residual,
952+
skip_warmup,
953+
mag_ratios,
954+
)
955+
956+
def magcache_step(
957+
step,
958+
mag_ratios,
959+
accumulated_state,
960+
magcache_thresh,
961+
magcache_K,
962+
skip_warmup,
963+
):
964+
"""Update MagCache accumulated state and decide if to skip.
965+
966+
Args:
967+
step: Current inference step.
968+
mag_ratios: Interpolated magnitude ratios array.
969+
accumulated_state: Tuple containing accumulated variables.
970+
magcache_thresh: Error threshold.
971+
magcache_K: Max skip steps.
972+
skip_warmup: Warmup steps threshold.
973+
"""
974+
import numpy as np
975+
976+
(
977+
accumulated_ratio_cond,
978+
accumulated_ratio_uncond,
979+
accumulated_err_cond,
980+
accumulated_err_uncond,
981+
accumulated_steps_cond,
982+
accumulated_steps_uncond,
983+
) = accumulated_state
984+
985+
cur_mag_ratio_cond = mag_ratios[step * 2]
986+
cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]
987+
988+
skip_blocks = False
989+
if step >= skip_warmup:
990+
new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
991+
new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
992+
993+
err_cond = np.abs(1.0 - new_ratio_cond)
994+
err_uncond = np.abs(1.0 - new_ratio_uncond)
995+
996+
if (
997+
accumulated_err_cond + err_cond < magcache_thresh
998+
and accumulated_steps_cond < magcache_K
999+
and accumulated_err_uncond + err_uncond < magcache_thresh
1000+
and accumulated_steps_uncond < magcache_K
1001+
):
1002+
skip_blocks = True
1003+
accumulated_ratio_cond = new_ratio_cond
1004+
accumulated_ratio_uncond = new_ratio_uncond
1005+
accumulated_err_cond += err_cond
1006+
accumulated_err_uncond += err_uncond
1007+
accumulated_steps_cond += 1
1008+
accumulated_steps_uncond += 1
1009+
else:
1010+
accumulated_ratio_cond = 1.0
1011+
accumulated_ratio_uncond = 1.0
1012+
accumulated_err_cond = 0.0
1013+
accumulated_err_uncond = 0.0
1014+
accumulated_steps_cond = 0
1015+
accumulated_steps_uncond = 0
1016+
1017+
new_state = (
1018+
accumulated_ratio_cond,
1019+
accumulated_ratio_uncond,
1020+
accumulated_err_cond,
1021+
accumulated_err_uncond,
1022+
accumulated_steps_cond,
1023+
accumulated_steps_uncond,
1024+
)
1025+
return skip_blocks, new_state

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 34 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp, init_magcache, magcache_step
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional, Any
1818
from ...pyconfig import HyperParameters
@@ -131,6 +131,7 @@ def __call__(
131131
magcache_K=magcache_K,
132132
retention_ratio=retention_ratio,
133133
height=height,
134+
mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None,
134135
)
135136

136137
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -163,6 +164,7 @@ def run_inference_2_1(
163164
magcache_K: int = 2,
164165
retention_ratio: float = 0.2,
165166
height: int = 480,
167+
mag_ratios_base: Optional[List[float]] = None,
166168
):
167169
"""Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
168170
@@ -233,59 +235,42 @@ def run_inference_2_1(
233235
cached_noise_cond = None
234236
cached_noise_uncond = None
235237

236-
if use_magcache and do_classifier_free_guidance:
237-
# ── MagCache Execution Path ──
238-
accumulated_ratio_cond = 1.0
239-
accumulated_ratio_uncond = 1.0
240-
accumulated_err_cond = 0.0
241-
accumulated_err_uncond = 0.0
242-
accumulated_steps_cond = 0
243-
accumulated_steps_uncond = 0
244-
cached_residual = None
245-
246-
skip_warmup = int(num_inference_steps * retention_ratio)
247-
248-
# 14B Ratios
249-
mag_ratios_base = np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189])
250-
251-
if len(mag_ratios_base) != num_inference_steps * 2:
252-
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
253-
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
254-
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)
255-
else:
256-
mag_ratios = mag_ratios_base
238+
if use_magcache and do_cfg:
239+
(
240+
accumulated_ratio_cond,
241+
accumulated_ratio_uncond,
242+
accumulated_err_cond,
243+
accumulated_err_uncond,
244+
accumulated_steps_cond,
245+
accumulated_steps_uncond,
246+
cached_residual,
247+
skip_warmup,
248+
mag_ratios,
249+
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
257250

258251
for step in range(num_inference_steps):
259252
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
260253
timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz)
261254

262-
skip_blocks = False
263-
if step >= skip_warmup:
264-
ratio_cond = mag_ratios[step * 2]
265-
ratio_uncond = mag_ratios[step * 2 + 1]
266-
267-
new_ratio_cond = accumulated_ratio_cond * ratio_cond
268-
new_ratio_uncond = accumulated_ratio_uncond * ratio_uncond
269-
270-
err_cond = np.abs(1.0 - new_ratio_cond)
271-
err_uncond = np.abs(1.0 - new_ratio_uncond)
272-
273-
if (accumulated_err_cond + err_cond < magcache_thresh and accumulated_steps_cond < magcache_K and
274-
accumulated_err_uncond + err_uncond < magcache_thresh and accumulated_steps_uncond < magcache_K):
275-
skip_blocks = True
276-
accumulated_ratio_cond = new_ratio_cond
277-
accumulated_ratio_uncond = new_ratio_uncond
278-
accumulated_err_cond += err_cond
279-
accumulated_err_uncond += err_uncond
280-
accumulated_steps_cond += 1
281-
accumulated_steps_uncond += 1
282-
else:
283-
accumulated_ratio_cond = 1.0
284-
accumulated_ratio_uncond = 1.0
285-
accumulated_err_cond = 0.0
286-
accumulated_err_uncond = 0.0
287-
accumulated_steps_cond = 0
288-
accumulated_steps_uncond = 0
255+
accumulated_state = (
256+
accumulated_ratio_cond,
257+
accumulated_ratio_uncond,
258+
accumulated_err_cond,
259+
accumulated_err_uncond,
260+
accumulated_steps_cond,
261+
accumulated_steps_uncond,
262+
)
263+
skip_blocks, accumulated_state = magcache_step(
264+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
265+
)
266+
(
267+
accumulated_ratio_cond,
268+
accumulated_ratio_uncond,
269+
accumulated_err_cond,
270+
accumulated_err_uncond,
271+
accumulated_steps_cond,
272+
accumulated_steps_uncond,
273+
) = accumulated_state
289274

290275
outputs = transformer_forward_pass(
291276
graphdef,

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp, init_magcache, magcache_step
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
@@ -169,6 +169,7 @@ def __call__(
169169
magcache_K=magcache_K,
170170
retention_ratio=retention_ratio,
171171
height=height,
172+
mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None,
172173
)
173174

174175
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -210,6 +211,7 @@ def run_inference_2_2(
210211
magcache_K: int = 2,
211212
retention_ratio: float = 0.2,
212213
height: int = 480,
214+
mag_ratios_base: Optional[List[float]] = None,
213215
):
214216
"""Denoising loop for WAN 2.2 T2V with optional caching acceleration.
215217
@@ -452,59 +454,44 @@ def run_inference_2_2(
452454

453455
# ── MagCache path ──
454456
if use_magcache and do_classifier_free_guidance:
455-
accumulated_ratio_cond = 1.0
456-
accumulated_ratio_uncond = 1.0
457-
accumulated_err_cond = 0.0
458-
accumulated_err_uncond = 0.0
459-
accumulated_steps_cond = 0
460-
accumulated_steps_uncond = 0
461-
cached_residual = None
462-
463-
skip_warmup = int(num_inference_steps * retention_ratio)
464-
465-
# Pre-calculated 2.2 T2V ratios
466-
mag_ratios_base = np.array([1.0]*2+[1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181])
467-
468-
if len(mag_ratios_base) != num_inference_steps * 2:
469-
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
470-
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
471-
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)
472-
else:
473-
mag_ratios = mag_ratios_base
457+
(
458+
accumulated_ratio_cond,
459+
accumulated_ratio_uncond,
460+
accumulated_err_cond,
461+
accumulated_err_uncond,
462+
accumulated_steps_cond,
463+
accumulated_steps_uncond,
464+
cached_residual,
465+
skip_warmup,
466+
mag_ratios,
467+
) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
474468

475469
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
476470
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
477471
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
478472

479473
for step in range(num_inference_steps):
480474
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
481-
cur_mag_ratio_cond = mag_ratios[step*2]
482-
cur_mag_ratio_uncond = mag_ratios[step*2+1]
483-
484-
skip_blocks = False
485-
if step >= skip_warmup:
486-
new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
487-
new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
488-
489-
err_cond = np.abs(1.0 - new_ratio_cond)
490-
err_uncond = np.abs(1.0 - new_ratio_uncond)
491-
492-
if (accumulated_err_cond + err_cond < magcache_thresh and accumulated_steps_cond < magcache_K and
493-
accumulated_err_uncond + err_uncond < magcache_thresh and accumulated_steps_uncond < magcache_K):
494-
skip_blocks = True
495-
accumulated_ratio_cond = new_ratio_cond
496-
accumulated_ratio_uncond = new_ratio_uncond
497-
accumulated_err_cond += err_cond
498-
accumulated_err_uncond += err_uncond
499-
accumulated_steps_cond += 1
500-
accumulated_steps_uncond += 1
501-
else:
502-
accumulated_ratio_cond = 1.0
503-
accumulated_ratio_uncond = 1.0
504-
accumulated_err_cond = 0.0
505-
accumulated_err_uncond = 0.0
506-
accumulated_steps_cond = 0
507-
accumulated_steps_uncond = 0
475+
476+
accumulated_state = (
477+
accumulated_ratio_cond,
478+
accumulated_ratio_uncond,
479+
accumulated_err_cond,
480+
accumulated_err_uncond,
481+
accumulated_steps_cond,
482+
accumulated_steps_uncond,
483+
)
484+
skip_blocks, accumulated_state = magcache_step(
485+
step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
486+
)
487+
(
488+
accumulated_ratio_cond,
489+
accumulated_ratio_uncond,
490+
accumulated_err_cond,
491+
accumulated_err_uncond,
492+
accumulated_steps_cond,
493+
accumulated_steps_uncond,
494+
) = accumulated_state
508495

509496
if step_uses_high[step]:
510497
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest

0 commit comments

Comments
 (0)