@@ -96,12 +96,13 @@ def __call__(
         magcache_K: Optional[int] = None,
         retention_ratio: Optional[float] = None,
     ):
+        config = getattr(self, "config", None)
         if magcache_thresh is None:
-            magcache_thresh = getattr(self.config, "magcache_thresh", 0.12)
+            magcache_thresh = getattr(config, "magcache_thresh", 0.12)
         if magcache_K is None:
-            magcache_K = getattr(self.config, "magcache_K", 2)
+            magcache_K = getattr(config, "magcache_K", 2)
         if retention_ratio is None:
-            retention_ratio = getattr(self.config, "retention_ratio", 0.2)
+            retention_ratio = getattr(config, "retention_ratio", 0.2)
 
         if use_cfg_cache and guidance_scale <= 1.0:
             raise ValueError(
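For context, here is a minimal, self-contained sketch of the fallback chain this hunk introduces. The class and names below are illustrative, not the pipeline's real API: an explicit call argument wins, then the pipeline config, then a hard-coded default, and `getattr(None, name, default)` keeps the chain safe when the object has no `config` at all.

```python
# Minimal sketch of the fallback chain (hypothetical class, not the pipeline's
# real API): explicit argument > pipeline config > hard-coded default.
# getattr(None, name, default) returns the default, so the chain also works
# when the object carries no config at all.
class _DemoPipeline:
    def __init__(self, config=None):
        if config is not None:
            self.config = config

    def __call__(self, magcache_thresh=None, magcache_K=None):
        config = getattr(self, "config", None)
        if magcache_thresh is None:
            magcache_thresh = getattr(config, "magcache_thresh", 0.12)
        if magcache_K is None:
            magcache_K = getattr(config, "magcache_K", 2)
        return magcache_thresh, magcache_K

print(_DemoPipeline()())                       # (0.12, 2): built-in defaults
print(_DemoPipeline()(magcache_thresh=0.05))   # (0.05, 2): explicit argument wins
```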
@@ -138,7 +139,7 @@ def __call__(
             magcache_K=magcache_K,
             retention_ratio=retention_ratio,
             height=height,
-            mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None,
+            mag_ratios_base=getattr(config, "mag_ratios_base", None),
         )
 
         with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -244,43 +245,23 @@ def run_inference_2_1(
     cached_noise_uncond = None
 
     if use_magcache and do_cfg:
-        (
-            accumulated_ratio_cond,
-            accumulated_ratio_uncond,
-            accumulated_err_cond,
-            accumulated_err_uncond,
-            accumulated_steps_cond,
-            accumulated_steps_uncond,
-            cached_residual,
-            skip_warmup,
-            mag_ratios,
-        ) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
-
-        for step in range(num_inference_steps):
-            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+        magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
+        accumulated_state = magcache_init[:6]
+        cached_residual = magcache_init[6]
+        skip_warmup = magcache_init[7]
+        mag_ratios = magcache_init[8]
+
+    for step in range(num_inference_steps):
+        t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+
+        if use_magcache and do_cfg:
             timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz)
 
-            accumulated_state = (
-                accumulated_ratio_cond,
-                accumulated_ratio_uncond,
-                accumulated_err_cond,
-                accumulated_err_uncond,
-                accumulated_steps_cond,
-                accumulated_steps_uncond,
-            )
             skip_blocks, accumulated_state = magcache_step(
                 step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
             )
-            (
-                accumulated_ratio_cond,
-                accumulated_ratio_uncond,
-                accumulated_err_cond,
-                accumulated_err_uncond,
-                accumulated_steps_cond,
-                accumulated_steps_uncond,
-            ) = accumulated_state
-
-            outputs = transformer_forward_pass(
+
+            noise_pred, latents, residual_x_cur = transformer_forward_pass(
                 graphdef,
                 sharded_state,
                 rest_of_state,
@@ -294,18 +275,10 @@ def run_inference_2_1(
                 return_residual=True,
             )
 
-            noise_pred, latents_returned, residual_x_cur = outputs
-
             if not skip_blocks:
                 cached_residual = residual_x_cur
 
-            latents = latents_returned
-            latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
-        return latents
-
-    else:
-        for step in range(num_inference_steps):
-            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+        else:
             is_cache_step = step_is_cache[step]
 
             if is_cache_step:
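The hunks above thread the MagCache accumulator through the loop as a single tuple instead of repeatedly packing and unpacking six named values. Below is a runnable toy of that threading pattern; the `toy_init_magcache`/`toy_magcache_step` helpers, their numeric values, and the skip rule are assumptions, not the real MaxDiffusion implementations.

```python
# Toy version of the state threading above. init returns a 9-tuple: the first
# six entries are the accumulator that the step helper consumes and returns,
# the last three are per-run values (cached residual, warm-up flag, ratios).
def toy_init_magcache(num_steps):
    accumulators = (1.0, 1.0, 0.0, 0.0, 0, 0)  # cond/uncond ratios, errors, skipped-step counts
    cached_residual = None
    skip_warmup = False                        # assumed: whether warm-up steps may be skipped
    mag_ratios = [1.0] * num_steps             # assumed per-step magnitude ratios
    return (*accumulators, cached_residual, skip_warmup, mag_ratios)

def toy_magcache_step(step, mag_ratios, state, thresh, K, skip_warmup):
    # Assumed decision rule: skip only when allowed and the ratio is small enough.
    skip_blocks = (skip_warmup or step >= K) and mag_ratios[step] < thresh
    return skip_blocks, state

init = toy_init_magcache(num_steps=4)
accumulated_state = init[:6]
cached_residual, skip_warmup, mag_ratios = init[6], init[7], init[8]
for step in range(4):
    skip_blocks, accumulated_state = toy_magcache_step(
        step, mag_ratios, accumulated_state, 0.12, 2, skip_warmup
    )
    print(step, skip_blocks)
```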
@@ -351,5 +324,6 @@ def run_inference_2_1(
                 guidance_scale=guidance_scale,
             )
 
-            latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
-    return latents
+        latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+
+    return latents
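Taken together, the hunks leave one step loop whose two branches both fall through to a single `scheduler.step(...).to_tuple()` call and one `return`. A runnable toy of that control-flow shape follows; the scheduler and both branch predictions are stand-ins, not the real MaxDiffusion objects (the real code calls `transformer_forward_pass` and the cache helpers here).

```python
# Runnable toy of the final loop shape: one loop, branch-specific prediction,
# then a single scheduler.step(...).to_tuple() and one return.
import jax.numpy as jnp

class _StepOutput:
    def __init__(self, prev_sample, state):
        self.prev_sample, self.state = prev_sample, state

    def to_tuple(self):
        return self.prev_sample, self.state

class ToyScheduler:
    def step(self, state, noise_pred, t, latents):
        # Plain Euler-style update standing in for the real Flax scheduler.
        return _StepOutput(latents - 0.1 * noise_pred, state)

def denoise(latents, timesteps, use_magcache):
    scheduler, scheduler_state = ToyScheduler(), {}
    for step in range(len(timesteps)):
        t = timesteps[step]
        if use_magcache:
            noise_pred = 0.5 * latents  # stand-in for the MagCache branch
        else:
            noise_pred = 0.4 * latents  # stand-in for the CFG-cache branch
        latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
    return latents

print(denoise(jnp.ones((2, 4)), jnp.arange(4), use_magcache=True))
```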