@@ -367,12 +367,33 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
367367
368368 x = self .norm1 (x )
369369 x = self .nonlinearity (x )
370- x = self .conv1 (x )
370+
371+ if feat_cache is not None :
372+ idx = feat_idx [0 ]
373+ cache_x = jnp .copy (x [:, - CACHE_T :, :, :, :])
374+ if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None :
375+ cache_x = jnp .concatenate ([jnp .expand_dims (feat_cache [idx ][:, - 1 , :, :, :], axis = 1 ), cache_x ], axis = 1 )
376+
377+ x = self .conv1 (x , feat_cache [idx ])
378+ feat_cache [idx ] = cache_x
379+ feat_idx [0 ] += 1
380+ else :
381+ x = self .conv1 (x )
371382
372383 x = self .norm2 (x )
373384 x = self .nonlinearity (x )
374385 x = self .dropout (x )
375- x = self .conv2 (x )
386+
387+ if feat_cache is not None :
388+ idx = feat_idx [0 ]
389+ cache_x = jnp .copy (x [:, - CACHE_T :, :, :, :])
390+ if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None :
391+ cache_x = jnp .concatenate ([jnp .expand_dims (feat_cache [idx ][:, - 1 , :, :, :], axis = 1 ), cache_x ], axis = 1 )
392+ x = self .conv2 (x , feat_cache [idx ])
393+ feat_cache [idx ] = cache_x
394+ feat_idx [0 ] += 1
395+ else :
396+ x = self .conv2 (x )
376397
377398 return x + h
378399
@@ -442,7 +463,7 @@ def __init__(
442463 self .resnets = resnets
443464
444465 def __call__ (self , x : jax .Array , feat_cache = None , feat_idx = [0 ]):
445- x = self .resnets [0 ](x )
466+ x = self .resnets [0 ](x , feat_cache , feat_idx )
446467 for attn , resnet in zip (self .attentions , self .resnets [1 :]):
447468 if attn is not None :
448469 x = attn (x )
0 commit comments