@@ -912,15 +912,13 @@ def nearest_interp(src, target_len):
912912 indices = np .round (np .linspace (0 , src_len - 1 , target_len )).astype (np .int32 )
913913 return src [indices ]
914914
915- def init_magcache (num_inference_steps , retention_ratio , mag_ratios_base , split_step = None , model_type = "T2V" ):
915+ def init_magcache (num_inference_steps , retention_ratio , mag_ratios_base ):
916916 """Initialize MagCache variables and interpolate ratios.
917917
918918 Args:
919919 num_inference_steps: Number of inference steps.
920920 retention_ratio: Retention ratio of unchanged steps.
921921 mag_ratios_base: Base magnitude ratios array or list.
922- split_step: Step at which model switches (e.g. high -> low noise for 2.2).
923- model_type: Pipeline mode ("T2V" or "I2V").
924922 """
925923 import numpy as np
926924
@@ -953,8 +951,6 @@ def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base, split_s
953951 cached_residual ,
954952 skip_warmup ,
955953 mag_ratios ,
956- split_step ,
957- model_type ,
958954 )
959955
960956def magcache_step (
@@ -964,10 +960,6 @@ def magcache_step(
964960 magcache_thresh ,
965961 magcache_K ,
966962 skip_warmup ,
967- split_step = None ,
968- model_type = "T2V" ,
969- num_steps = None ,
970- retention_ratio = 0.2 ,
971963):
972964 """Update MagCache accumulated state and decide if to skip.
973965
@@ -978,10 +970,6 @@ def magcache_step(
978970 magcache_thresh: Error threshold.
979971 magcache_K: Max skip steps.
980972 skip_warmup: Warmup steps threshold.
981- split_step: Optional step index where the model switches (e.g., from high to low noise).
982- model_type: Pipeline type ("T2V" or "I2V").
983- num_steps: Total inference steps, used to calculate post-split warmups.
984- retention_ratio: Used to calculate post-split warmups.
985973 """
986974 import numpy as np
987975
@@ -998,16 +986,8 @@ def magcache_step(
998986 cur_mag_ratio_uncond = mag_ratios [step * 2 + 1 ]
999987
1000988 use_magcache = True
1001- if split_step is not None :
1002- if model_type == "I2V" :
1003- if step < int (split_step + (num_steps - split_step ) * retention_ratio ):
1004- use_magcache = False
1005- else :
1006- if step < int (split_step * retention_ratio ) or (step <= ((num_steps - split_step ) * retention_ratio + split_step ) and step >= split_step ):
1007- use_magcache = False
1008- else :
1009- if step < skip_warmup :
1010- use_magcache = False
989+ if step < skip_warmup :
990+ use_magcache = False
1011991
1012992 skip_blocks = False
1013993 if use_magcache :
0 commit comments