@@ -71,6 +71,7 @@ def __init__(
71 71             padding, 3, "padding"
72 72         )  # (D, H, W) padding amounts
73 73         self.mesh = mesh
   74+        self.dtype = dtype
74 75
75 76         self._causal_padding = (
76 77             (0, 0),  # Batch dimension - no padding
@@ -154,7 +155,7 @@ def __call__(
154 155
155 156         if cache_x is not None:
156 157             x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
157-                new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
    158+            new_cache = x_concat[:, -CACHE_T:, ...].astype(self.dtype)
158 159
159 160             padding_needed = self._depth_padding_before - cache_x.shape[1]
160 161             if padding_needed < 0:
@@ -164,7 +165,7 @@ def __call__(
164 165             x_input = x_concat
165 166             current_padding[1] = (padding_needed, 0)
166 167         else:
167-                new_cache = x[:, -CACHE_T:, ...]
    168+            new_cache = x[:, -CACHE_T:, ...].astype(self.dtype)
168 169             x_input = x
169 170
170 171         padding_to_apply = tuple(current_padding)
@@ -443,6 +444,7 @@ def __init__(
443 444         weights_dtype: jnp.dtype = jnp.float32,
444 445         precision: jax.lax.Precision = None,
445 446     ):
    447+        self.dtype = dtype
446 448         self.nonlinearity = get_activation(non_linearity)
447 449         self.norm1 = WanRMS_norm(
448 450             dim=in_dim, rngs=rngs, images=False, channel_first=False
@@ -520,8 +522,8 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
520 522         x, c2 = self.conv2(x, cache.get("conv2"))
521 523         new_cache["conv2"] = c2
522 524
523-            out = (x + h).astype(input_dtype)
524-            return out, new_cache
    525+        x = (x + h).astype(self.dtype)
    526+        return x, new_cache
525 527
526 528
527 529 class WanAttentionBlock(nnx.Module):
@@ -535,6 +537,7 @@ def __init__(
535 537         precision: jax.lax.Precision = None,
536 538     ):
537 539         self.dim = dim
    540+        self.dtype = dtype
538 541         self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
539 542         self.to_qkv = nnx.Conv(
540 543             in_features=dim,
@@ -578,7 +581,7 @@ def __call__(self, x: jax.Array):
578 581         x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
579 582         x = self.proj(x)
580 583         x = x.reshape(batch_size, time, height, width, channels)
581-            out = (x + identity).astype(input_dtype)
    584+        out = (x + identity).astype(self.dtype)
582 585         return out
583 586
584 587
0 commit comments