@@ -912,13 +912,15 @@ def nearest_interp(src, target_len):
912912 indices = np .round (np .linspace (0 , src_len - 1 , target_len )).astype (np .int32 )
913913 return src [indices ]
914914
915- def init_magcache (num_inference_steps , retention_ratio , mag_ratios_base ):
915+ def init_magcache (num_inference_steps , retention_ratio , mag_ratios_base , split_step = None , model_type = "T2V" ):
916916 """Initialize MagCache variables and interpolate ratios.
917917
918918 Args:
919919 num_inference_steps: Number of inference steps.
920920 retention_ratio: Retention ratio of unchanged steps.
921921 mag_ratios_base: Base magnitude ratios array or list.
922+ split_step: Step at which model switches (e.g. high -> low noise for 2.2).
923+ model_type: Pipeline mode ("T2V" or "I2V").
922924 """
923925 import numpy as np
924926
@@ -951,6 +953,8 @@ def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
951953 cached_residual ,
952954 skip_warmup ,
953955 mag_ratios ,
956+ split_step ,
957+ model_type ,
954958 )
955959
956960def magcache_step (
@@ -960,6 +964,10 @@ def magcache_step(
960964 magcache_thresh ,
961965 magcache_K ,
962966 skip_warmup ,
967+ split_step = None ,
968+ model_type = "T2V" ,
969+ num_steps = None ,
970+ retention_ratio = 0.2 ,
963971):
964972 """Update MagCache accumulated state and decide if to skip.
965973
@@ -970,6 +978,10 @@ def magcache_step(
970978 magcache_thresh: Error threshold.
971979 magcache_K: Max skip steps.
972980 skip_warmup: Warmup steps threshold.
981+ split_step: Optional step index where the model switches (e.g., from high to low noise).
982+ model_type: Pipeline type ("T2V" or "I2V").
983+ num_steps: Total inference steps, used to calculate post-split warmups.
984+ retention_ratio: Used to calculate post-split warmups.
973985 """
974986 import numpy as np
975987
@@ -985,8 +997,20 @@ def magcache_step(
985997 cur_mag_ratio_cond = mag_ratios [step * 2 ]
986998 cur_mag_ratio_uncond = mag_ratios [step * 2 + 1 ]
987999
1000+ use_magcache = True
1001+ if split_step is not None :
1002+ if model_type == "I2V" :
1003+ if step < int (split_step + (num_steps - split_step ) * retention_ratio ):
1004+ use_magcache = False
1005+ else :
1006+ if step < int (split_step * retention_ratio ) or (step <= ((num_steps - split_step ) * retention_ratio + split_step ) and step >= split_step ):
1007+ use_magcache = False
1008+ else :
1009+ if step < skip_warmup :
1010+ use_magcache = False
1011+
9881012 skip_blocks = False
989- if step >= skip_warmup :
1013+ if use_magcache :
9901014 new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
9911015 new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
9921016
0 commit comments