|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp |
| 15 | +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp, init_magcache, magcache_step |
16 | 16 | from ...models.wan.transformers.transformer_wan import WanModel |
17 | 17 | from typing import List, Union, Optional, Any |
18 | 18 | from ...pyconfig import HyperParameters |
@@ -131,6 +131,7 @@ def __call__( |
131 | 131 | magcache_K=magcache_K, |
132 | 132 | retention_ratio=retention_ratio, |
133 | 133 | height=height, |
| 134 | + mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None, |
134 | 135 | ) |
135 | 136 |
|
136 | 137 | with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): |
@@ -163,6 +164,7 @@ def run_inference_2_1( |
163 | 164 | magcache_K: int = 2, |
164 | 165 | retention_ratio: float = 0.2, |
165 | 166 | height: int = 480, |
| 167 | + mag_ratios_base: Optional[List[float]] = None, |
166 | 168 | ): |
167 | 169 | """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. |
168 | 170 |
|
@@ -233,59 +235,42 @@ def run_inference_2_1( |
233 | 235 | cached_noise_cond = None |
234 | 236 | cached_noise_uncond = None |
235 | 237 |
|
236 | | - if use_magcache and do_classifier_free_guidance: |
237 | | - # ── MagCache Execution Path ── |
238 | | - accumulated_ratio_cond = 1.0 |
239 | | - accumulated_ratio_uncond = 1.0 |
240 | | - accumulated_err_cond = 0.0 |
241 | | - accumulated_err_uncond = 0.0 |
242 | | - accumulated_steps_cond = 0 |
243 | | - accumulated_steps_uncond = 0 |
244 | | - cached_residual = None |
245 | | - |
246 | | - skip_warmup = int(num_inference_steps * retention_ratio) |
247 | | - |
248 | | - # 14B Ratios |
249 | | - mag_ratios_base = np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]) |
250 | | - |
251 | | - if len(mag_ratios_base) != num_inference_steps * 2: |
252 | | - mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps) |
253 | | - mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps) |
254 | | - mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1) |
255 | | - else: |
256 | | - mag_ratios = mag_ratios_base |
| 238 | + if use_magcache and do_cfg: |
| 239 | + ( |
| 240 | + accumulated_ratio_cond, |
| 241 | + accumulated_ratio_uncond, |
| 242 | + accumulated_err_cond, |
| 243 | + accumulated_err_uncond, |
| 244 | + accumulated_steps_cond, |
| 245 | + accumulated_steps_uncond, |
| 246 | + cached_residual, |
| 247 | + skip_warmup, |
| 248 | + mag_ratios, |
| 249 | + ) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) |
257 | 250 |
|
258 | 251 | for step in range(num_inference_steps): |
259 | 252 | t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] |
260 | 253 | timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz) |
261 | 254 |
|
262 | | - skip_blocks = False |
263 | | - if step >= skip_warmup: |
264 | | - ratio_cond = mag_ratios[step * 2] |
265 | | - ratio_uncond = mag_ratios[step * 2 + 1] |
266 | | - |
267 | | - new_ratio_cond = accumulated_ratio_cond * ratio_cond |
268 | | - new_ratio_uncond = accumulated_ratio_uncond * ratio_uncond |
269 | | - |
270 | | - err_cond = np.abs(1.0 - new_ratio_cond) |
271 | | - err_uncond = np.abs(1.0 - new_ratio_uncond) |
272 | | - |
273 | | - if (accumulated_err_cond + err_cond < magcache_thresh and accumulated_steps_cond < magcache_K and |
274 | | - accumulated_err_uncond + err_uncond < magcache_thresh and accumulated_steps_uncond < magcache_K): |
275 | | - skip_blocks = True |
276 | | - accumulated_ratio_cond = new_ratio_cond |
277 | | - accumulated_ratio_uncond = new_ratio_uncond |
278 | | - accumulated_err_cond += err_cond |
279 | | - accumulated_err_uncond += err_uncond |
280 | | - accumulated_steps_cond += 1 |
281 | | - accumulated_steps_uncond += 1 |
282 | | - else: |
283 | | - accumulated_ratio_cond = 1.0 |
284 | | - accumulated_ratio_uncond = 1.0 |
285 | | - accumulated_err_cond = 0.0 |
286 | | - accumulated_err_uncond = 0.0 |
287 | | - accumulated_steps_cond = 0 |
288 | | - accumulated_steps_uncond = 0 |
| 255 | + accumulated_state = ( |
| 256 | + accumulated_ratio_cond, |
| 257 | + accumulated_ratio_uncond, |
| 258 | + accumulated_err_cond, |
| 259 | + accumulated_err_uncond, |
| 260 | + accumulated_steps_cond, |
| 261 | + accumulated_steps_uncond, |
| 262 | + ) |
| 263 | + skip_blocks, accumulated_state = magcache_step( |
| 264 | + step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup |
| 265 | + ) |
| 266 | + ( |
| 267 | + accumulated_ratio_cond, |
| 268 | + accumulated_ratio_uncond, |
| 269 | + accumulated_err_cond, |
| 270 | + accumulated_err_uncond, |
| 271 | + accumulated_steps_cond, |
| 272 | + accumulated_steps_uncond, |
| 273 | + ) = accumulated_state |
289 | 274 |
|
290 | 275 | outputs = transformer_forward_pass( |
291 | 276 | graphdef, |
|
0 commit comments