@@ -153,7 +153,7 @@ def __call__(
         current_padding = list(self._causal_padding)
 
         if cache_x is not None:
-            x_concat = jnp.concatenate([cache_x, x], axis=1)
+            x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
             new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
 
             padding_needed = self._depth_padding_before - cache_x.shape[1]
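For context on what the added `astype` buys: `jnp.concatenate` applies NumPy-style type promotion, so mixing a float32 cache with bfloat16 activations silently upcasts the whole result to float32 (and can error under JAX's strict promotion mode). A minimal sketch outside the diff, with illustrative shapes and dtypes:

```python
import jax.numpy as jnp

cache = jnp.ones((1, 2, 4), dtype=jnp.float32)  # rolling cache kept in float32
x = jnp.ones((1, 8, 4), dtype=jnp.bfloat16)     # current activations in bfloat16

promoted = jnp.concatenate([cache, x], axis=1)
print(promoted.dtype)  # float32: promotion upcasts the concatenated result

fixed = jnp.concatenate([cache.astype(x.dtype), x], axis=1)
print(fixed.dtype)     # bfloat16: compute stays in the activation dtype
```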
@@ -198,14 +198,17 @@ def __init__(
         self.bias = nnx.Param(jnp.zeros(shape)) if use_bias else 0
 
     def __call__(self, x: jax.Array) -> jax.Array:
+        input_dtype = x.dtype
         normalized = jnp.linalg.norm(
             x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True
         )
         normalized = x / jnp.maximum(normalized, self.eps)
         normalized = normalized * self.scale * self.gamma
         if self.bias:
-            return normalized + self.bias.value
-        return normalized
+            out = normalized + self.bias.value
+        else:
+            out = normalized
+        return out.astype(input_dtype)
 
 
 class WanUpsample(nnx.Module):
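For intuition on why this hunk records `input_dtype` and casts back: when the layer's learned scale lives in float32 (as parameters typically do), multiplying it into bfloat16 activations promotes the output to float32. A sketch outside the diff, using a hypothetical float32 `gamma` as a stand-in for the layer's parameter:

```python
import jax.numpy as jnp

x = jnp.ones((2, 8), dtype=jnp.bfloat16)
gamma = jnp.ones((8,), dtype=jnp.float32)  # params commonly kept in float32

norm = jnp.linalg.norm(x, ord=2, axis=-1, keepdims=True)
out = (x / jnp.maximum(norm, 1e-6)) * gamma
print(out.dtype)                  # float32: the float32 param wins promotion
print(out.astype(x.dtype).dtype)  # bfloat16, matching the caller's input
```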
@@ -385,7 +388,7 @@ def __call__(
 
         elif self.mode == "upsample3d":
             x, tc_cache = self.time_conv(x, cache.get("time_conv"))
-            new_cache["time_conv"] = tc_cache
+            new_cache["time_conv"] = tc_cache.astype(cache["time_conv"].dtype)
 
             b, t, h, w, c = x.shape
             x = x.reshape(b, t, h, w, 2, c // 2)
@@ -413,7 +416,7 @@ def __call__(
             x = x.reshape(b, t, h_new, w_new, c_new)
 
             prev_cache = cache.get("time_conv")
-            x_combined = jnp.concatenate([prev_cache, x], axis=1)
+            x_combined = jnp.concatenate([prev_cache.astype(x.dtype), x], axis=1)
             x, _ = self.time_conv(x_combined, cache_x=None)
             new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(prev_cache.dtype)
 
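Taken together, the hunks converge on one convention: compute in the activation dtype, but persist the rolling cache in its original (e.g. float32) dtype. A round-trip sketch with a hypothetical `step` helper, not from the diff:

```python
import jax.numpy as jnp

CACHE_T = 2  # matches the constant used in the diff

def step(x, cache):
    """One decode step: consume `cache`, return output plus refreshed cache."""
    combined = jnp.concatenate([cache.astype(x.dtype), x], axis=1)
    new_cache = combined[:, -CACHE_T:, ...].astype(cache.dtype)
    return combined, new_cache

cache = jnp.zeros((1, CACHE_T, 4), dtype=jnp.float32)
x = jnp.ones((1, 3, 4), dtype=jnp.bfloat16)
out, cache = step(x, cache)
print(out.dtype, cache.dtype)  # bfloat16 float32
```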