@@ -775,7 +775,7 @@ def transformer_forward_pass(
775775 cached_residual = cached_residual ,
776776 return_residual = return_residual ,
777777 )
778-
778+
779779 if return_residual :
780780 noise_pred , residual_x = outputs
781781 else :
@@ -899,56 +899,61 @@ def transformer_forward_pass_cfg_cache(
899899 noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx )
900900 return noise_pred_merged , noise_cond
901901
902+
902903def nearest_interp (src , target_len ):
903- """Nearest neighbor interpolation for ratio scaling layout."""
904- src_len = len (src )
905- if target_len == 1 :
906- import numpy as np
907- return np .array ([src [- 1 ]])
904+ """Nearest neighbor interpolation for ratio scaling layout."""
905+ src_len = len (src )
906+ if target_len == 1 :
908907 import numpy as np
909- indices = np .round (np .linspace (0 , src_len - 1 , target_len )).astype (np .int32 )
910- return src [indices ]
908+
909+ return np .array ([src [- 1 ]])
910+ import numpy as np
911+
912+ indices = np .round (np .linspace (0 , src_len - 1 , target_len )).astype (np .int32 )
913+ return src [indices ]
914+
911915
912916def init_magcache (num_inference_steps , retention_ratio , mag_ratios_base ):
913- """Initialize MagCache variables and interpolate ratios.
914-
915- Args:
916- num_inference_steps: Number of inference steps.
917- retention_ratio: Retention ratio of unchanged steps.
918- mag_ratios_base: Base magnitude ratios array or list.
919- """
920- import numpy as np
921-
922- accumulated_ratio_cond = 1.0
923- accumulated_ratio_uncond = 1.0
924- accumulated_err_cond = 0.0
925- accumulated_err_uncond = 0.0
926- accumulated_steps_cond = 0
927- accumulated_steps_uncond = 0
928- cached_residual = None
929-
930- skip_warmup = int (num_inference_steps * retention_ratio )
931-
932- mag_ratios_base = np .array (mag_ratios_base )
933-
934- if len (mag_ratios_base ) != num_inference_steps * 2 :
935- mag_cond = nearest_interp (mag_ratios_base [0 ::2 ], num_inference_steps )
936- mag_uncond = nearest_interp (mag_ratios_base [1 ::2 ], num_inference_steps )
937- mag_ratios = np .concatenate ([mag_cond .reshape (- 1 , 1 ), mag_uncond .reshape (- 1 , 1 )], axis = 1 ).reshape (- 1 )
938- else :
939- mag_ratios = mag_ratios_base
940-
941- return (
942- accumulated_ratio_cond ,
943- accumulated_ratio_uncond ,
944- accumulated_err_cond ,
945- accumulated_err_uncond ,
946- accumulated_steps_cond ,
947- accumulated_steps_uncond ,
948- cached_residual ,
949- skip_warmup ,
950- mag_ratios ,
951- )
917+ """Initialize MagCache variables and interpolate ratios.
918+
919+ Args:
920+ num_inference_steps: Number of inference steps.
921+ retention_ratio: Retention ratio of unchanged steps.
922+ mag_ratios_base: Base magnitude ratios array or list.
923+ """
924+ import numpy as np
925+
926+ accumulated_ratio_cond = 1.0
927+ accumulated_ratio_uncond = 1.0
928+ accumulated_err_cond = 0.0
929+ accumulated_err_uncond = 0.0
930+ accumulated_steps_cond = 0
931+ accumulated_steps_uncond = 0
932+ cached_residual = None
933+
934+ skip_warmup = int (num_inference_steps * retention_ratio )
935+
936+ mag_ratios_base = np .array (mag_ratios_base )
937+
938+ if len (mag_ratios_base ) != num_inference_steps * 2 :
939+ mag_cond = nearest_interp (mag_ratios_base [0 ::2 ], num_inference_steps )
940+ mag_uncond = nearest_interp (mag_ratios_base [1 ::2 ], num_inference_steps )
941+ mag_ratios = np .concatenate ([mag_cond .reshape (- 1 , 1 ), mag_uncond .reshape (- 1 , 1 )], axis = 1 ).reshape (- 1 )
942+ else :
943+ mag_ratios = mag_ratios_base
944+
945+ return (
946+ accumulated_ratio_cond ,
947+ accumulated_ratio_uncond ,
948+ accumulated_err_cond ,
949+ accumulated_err_uncond ,
950+ accumulated_steps_cond ,
951+ accumulated_steps_uncond ,
952+ cached_residual ,
953+ skip_warmup ,
954+ mag_ratios ,
955+ )
956+
952957
953958def magcache_step (
954959 step ,
@@ -959,71 +964,71 @@ def magcache_step(
959964 skip_warmup = 0 ,
960965 use_magcache = None ,
961966):
962- """Update MagCache accumulated state and decide if to skip.
963-
964- Args:
965- step: Current inference step.
966- mag_ratios: Interpolated magnitude ratios array.
967- accumulated_state: Tuple containing accumulated variables.
968- magcache_thresh: Error threshold.
969- magcache_K: Max skip steps.
970- skip_warmup: Warmup steps threshold.
971- use_magcache: Optional manual override boolean to enable/disable cache for this step.
972- """
973- import numpy as np
974-
975- (
976- accumulated_ratio_cond ,
977- accumulated_ratio_uncond ,
978- accumulated_err_cond ,
979- accumulated_err_uncond ,
980- accumulated_steps_cond ,
981- accumulated_steps_uncond ,
982- ) = accumulated_state
983-
984- cur_mag_ratio_cond = mag_ratios [step * 2 ]
985- cur_mag_ratio_uncond = mag_ratios [step * 2 + 1 ]
986-
987- if use_magcache is None :
988- use_magcache = True
989- if step < skip_warmup :
990- use_magcache = False
991-
992- skip_blocks = False
993- if use_magcache :
994- new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
995- new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
996-
997- err_cond = np .abs (1.0 - new_ratio_cond )
998- err_uncond = np .abs (1.0 - new_ratio_uncond )
999-
1000- if (
1001- accumulated_err_cond + err_cond < magcache_thresh
1002- and accumulated_steps_cond < magcache_K
1003- and accumulated_err_uncond + err_uncond < magcache_thresh
1004- and accumulated_steps_uncond < magcache_K
1005- ):
1006- skip_blocks = True
1007- accumulated_ratio_cond = new_ratio_cond
1008- accumulated_ratio_uncond = new_ratio_uncond
1009- accumulated_err_cond += err_cond
1010- accumulated_err_uncond += err_uncond
1011- accumulated_steps_cond += 1
1012- accumulated_steps_uncond += 1
1013- else :
1014- accumulated_ratio_cond = 1.0
1015- accumulated_ratio_uncond = 1.0
1016- accumulated_err_cond = 0.0
1017- accumulated_err_uncond = 0.0
1018- accumulated_steps_cond = 0
1019- accumulated_steps_uncond = 0
1020-
1021- new_state = (
1022- accumulated_ratio_cond ,
1023- accumulated_ratio_uncond ,
1024- accumulated_err_cond ,
1025- accumulated_err_uncond ,
1026- accumulated_steps_cond ,
1027- accumulated_steps_uncond ,
1028- )
1029- return skip_blocks , new_state
967+ """Update MagCache accumulated state and decide if to skip.
968+
969+ Args:
970+ step: Current inference step.
971+ mag_ratios: Interpolated magnitude ratios array.
972+ accumulated_state: Tuple containing accumulated variables.
973+ magcache_thresh: Error threshold.
974+ magcache_K: Max skip steps.
975+ skip_warmup: Warmup steps threshold.
976+ use_magcache: Optional manual override boolean to enable/disable cache for this step.
977+ """
978+ import numpy as np
979+
980+ (
981+ accumulated_ratio_cond ,
982+ accumulated_ratio_uncond ,
983+ accumulated_err_cond ,
984+ accumulated_err_uncond ,
985+ accumulated_steps_cond ,
986+ accumulated_steps_uncond ,
987+ ) = accumulated_state
988+
989+ cur_mag_ratio_cond = mag_ratios [step * 2 ]
990+ cur_mag_ratio_uncond = mag_ratios [step * 2 + 1 ]
991+
992+ if use_magcache is None :
993+ use_magcache = True
994+ if step < skip_warmup :
995+ use_magcache = False
996+
997+ skip_blocks = False
998+ if use_magcache :
999+ new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
1000+ new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
1001+
1002+ err_cond = np .abs (1.0 - new_ratio_cond )
1003+ err_uncond = np .abs (1.0 - new_ratio_uncond )
1004+
1005+ if (
1006+ accumulated_err_cond + err_cond < magcache_thresh
1007+ and accumulated_steps_cond < magcache_K
1008+ and accumulated_err_uncond + err_uncond < magcache_thresh
1009+ and accumulated_steps_uncond < magcache_K
1010+ ):
1011+ skip_blocks = True
1012+ accumulated_ratio_cond = new_ratio_cond
1013+ accumulated_ratio_uncond = new_ratio_uncond
1014+ accumulated_err_cond += err_cond
1015+ accumulated_err_uncond += err_uncond
1016+ accumulated_steps_cond += 1
1017+ accumulated_steps_uncond += 1
1018+ else :
1019+ accumulated_ratio_cond = 1.0
1020+ accumulated_ratio_uncond = 1.0
1021+ accumulated_err_cond = 0.0
1022+ accumulated_err_uncond = 0.0
1023+ accumulated_steps_cond = 0
1024+ accumulated_steps_uncond = 0
1025+
1026+ new_state = (
1027+ accumulated_ratio_cond ,
1028+ accumulated_ratio_uncond ,
1029+ accumulated_err_cond ,
1030+ accumulated_err_uncond ,
1031+ accumulated_steps_cond ,
1032+ accumulated_steps_uncond ,
1033+ )
1034+ return skip_blocks , new_state
0 commit comments