@@ -136,6 +136,7 @@ def initialize_cache(self, batch_size, height, width, dtype):
     def __call__(
         self, x: jax.Array, cache_x: Optional[jax.Array] = None
     ) -> Tuple[jax.Array, jax.Array]:
+        x = x.astype(self.dtype)
         # OPTIMIZATION: Spatial Partitioning during execution
         if self.mesh is not None and "fsdp" in self.mesh.axis_names:
             height = x.shape[2]
@@ -155,7 +156,7 @@ def __call__(

         if cache_x is not None:
             x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
-            new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
+            new_cache = x_concat[:, -CACHE_T:, ...].astype(self.dtype)

             padding_needed = self._depth_padding_before - cache_x.shape[1]
             if padding_needed < 0:
@@ -165,7 +166,7 @@ def __call__(
                 x_input = x_concat
             current_padding[1] = (padding_needed, 0)
         else:
-            new_cache = x[:, -CACHE_T:, ...].astype(x.dtype)
+            new_cache = x[:, -CACHE_T:, ...].astype(self.dtype)
             x_input = x

         padding_to_apply = tuple(current_padding)
@@ -376,6 +377,7 @@ def initialize_cache(self, batch_size, height, width, dtype):
     def __call__(
         self, x: jax.Array, cache: Dict[str, Any] = None
     ) -> Tuple[jax.Array, Dict[str, Any]]:
+        x = x.astype(self.dtype)
         if cache is None:
             cache = {}
         new_cache = {}
@@ -389,7 +391,7 @@ def __call__(

         elif self.mode == "upsample3d":
             x, tc_cache = self.time_conv(x, cache.get("time_conv"))
-            new_cache["time_conv"] = tc_cache.astype(cache["time_conv"].dtype)
+            new_cache["time_conv"] = tc_cache.astype(self.dtype)

             b, t, h, w, c = x.shape
             x = x.reshape(b, t, h, w, 2, c // 2)
@@ -419,7 +421,7 @@ def __call__(
                 prev_cache = cache.get("time_conv")
                 x_combined = jnp.concatenate([prev_cache.astype(x.dtype), x], axis=1)
                 x, _ = self.time_conv(x_combined, cache_x=None)
-                new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(prev_cache.dtype)
+                new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(self.dtype)

         else:
             if hasattr(self, "resample"):
@@ -522,7 +524,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
         x, c2 = self.conv2(x, cache.get("conv2"))
         new_cache["conv2"] = c2

-        x = (x + h).astype(input_dtype)
+        x = (x + h).astype(self.dtype)
         return x, new_cache
0 commit comments