Skip to content

Commit 40501f8

Browse files
committed
fix for dtypes
1 parent 5f5ca6a commit 40501f8

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
503503
if cache is None:
504504
cache = {}
505505
new_cache = {}
506+
input_dtype = x.dtype
506507

507508
h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
508509
new_cache["shortcut"] = sc_cache
@@ -519,8 +520,8 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
519520
x, c2 = self.conv2(x, cache.get("conv2"))
520521
new_cache["conv2"] = c2
521522

522-
x = x + h
523-
return x, new_cache
523+
out = (x + h).astype(input_dtype)
524+
return out, new_cache
524525

525526

526527
class WanAttentionBlock(nnx.Module):
@@ -562,6 +563,7 @@ def __init__(
562563

563564
def __call__(self, x: jax.Array):
564565
identity = x
566+
input_dtype = x.dtype
565567
batch_size, time, height, width, channels = x.shape
566568
x = x.reshape(batch_size * time, height, width, channels)
567569
x = self.norm(x)
@@ -576,7 +578,8 @@ def __call__(self, x: jax.Array):
576578
x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
577579
x = self.proj(x)
578580
x = x.reshape(batch_size, time, height, width, channels)
579-
return x + identity
581+
out = (x + identity).astype(input_dtype)
582+
return out
580583

581584

582585
class WanMidBlock(nnx.Module):

0 commit comments

Comments
 (0)