@@ -155,7 +155,7 @@ def __call__(
155155
156156 if cache_x is not None :
157157 x_concat = jnp .concatenate ([cache_x .astype (x .dtype ), x ], axis = 1 )
158- new_cache = x_concat [:, - CACHE_T :, ...].astype (self .dtype )
158+ new_cache = x_concat [:, - CACHE_T :, ...].astype (cache_x .dtype )
159159
160160 padding_needed = self ._depth_padding_before - cache_x .shape [1 ]
161161 if padding_needed < 0 :
@@ -165,7 +165,7 @@ def __call__(
165165 x_input = x_concat
166166 current_padding [1 ] = (padding_needed , 0 )
167167 else :
168- new_cache = x [:, - CACHE_T :, ...].astype (self .dtype )
168+ new_cache = x [:, - CACHE_T :, ...].astype (x .dtype )
169169 x_input = x
170170
171171 padding_to_apply = tuple (current_padding )
@@ -522,7 +522,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
522522 x , c2 = self .conv2 (x , cache .get ("conv2" ))
523523 new_cache ["conv2" ] = c2
524524
525- x = (x + h ).astype (self . dtype )
525+ x = (x + h ).astype (input_dtype )
526526 return x , new_cache
527527
528528
@@ -581,7 +581,7 @@ def __call__(self, x: jax.Array):
581581 x = jnp .squeeze (x , 1 ).reshape (batch_size * time , height , width , channels )
582582 x = self .proj (x )
583583 x = x .reshape (batch_size , time , height , width , channels )
584- out = (x + identity ).astype (self . dtype )
584+ out = (x + identity ).astype (input_dtype )
585585 return out
586586
587587
0 commit comments