@@ -503,6 +503,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
503503 if cache is None :
504504 cache = {}
505505 new_cache = {}
506+ input_dtype = x .dtype
506507
507508 h , sc_cache = self .conv_shortcut (x , cache .get ("shortcut" ))
508509 new_cache ["shortcut" ] = sc_cache
@@ -519,8 +520,8 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
519520 x , c2 = self .conv2 (x , cache .get ("conv2" ))
520521 new_cache ["conv2" ] = c2
521522
522- x = x + h
523- return x , new_cache
523+ out = ( x + h ). astype ( input_dtype )
524+ return out , new_cache
524525
525526
526527class WanAttentionBlock (nnx .Module ):
@@ -562,6 +563,7 @@ def __init__(
562563
563564 def __call__ (self , x : jax .Array ):
564565 identity = x
566+ input_dtype = x .dtype
565567 batch_size , time , height , width , channels = x .shape
566568 x = x .reshape (batch_size * time , height , width , channels )
567569 x = self .norm (x )
@@ -576,7 +578,8 @@ def __call__(self, x: jax.Array):
576578 x = jnp .squeeze (x , 1 ).reshape (batch_size * time , height , width , channels )
577579 x = self .proj (x )
578580 x = x .reshape (batch_size , time , height , width , channels )
579- return x + identity
581+ out = (x + identity ).astype (input_dtype )
582+ return out
580583
581584
582585class WanMidBlock (nnx .Module ):
0 commit comments