Skip to content

Commit 5f5ca6a

Browse files
committed
Fix dtype mismatches in WAN autoencoder: cast cached activations to the input dtype before concatenation, preserve each cache's stored dtype when updating it, and restore the input dtype on the normalization output.
1 parent f5342fd commit 5f5ca6a

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __call__(
153153
current_padding = list(self._causal_padding)
154154

155155
if cache_x is not None:
156-
x_concat = jnp.concatenate([cache_x, x], axis=1)
156+
x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
157157
new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
158158

159159
padding_needed = self._depth_padding_before - cache_x.shape[1]
@@ -198,14 +198,17 @@ def __init__(
198198
self.bias = nnx.Param(jnp.zeros(shape)) if use_bias else 0
199199

200200
def __call__(self, x: jax.Array) -> jax.Array:
201+
input_dtype = x.dtype
201202
normalized = jnp.linalg.norm(
202203
x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True
203204
)
204205
normalized = x / jnp.maximum(normalized, self.eps)
205206
normalized = normalized * self.scale * self.gamma
206207
if self.bias:
207-
return normalized + self.bias.value
208-
return normalized
208+
out = normalized + self.bias.value
209+
else:
210+
out = normalized
211+
return out.astype(input_dtype)
209212

210213

211214
class WanUpsample(nnx.Module):
@@ -385,7 +388,7 @@ def __call__(
385388

386389
elif self.mode == "upsample3d":
387390
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
388-
new_cache["time_conv"] = tc_cache
391+
new_cache["time_conv"] = tc_cache.astype(cache["time_conv"].dtype)
389392

390393
b, t, h, w, c = x.shape
391394
x = x.reshape(b, t, h, w, 2, c // 2)
@@ -413,7 +416,7 @@ def __call__(
413416
x = x.reshape(b, t, h_new, w_new, c_new)
414417

415418
prev_cache = cache.get("time_conv")
416-
x_combined = jnp.concatenate([prev_cache, x], axis=1)
419+
x_combined = jnp.concatenate([prev_cache.astype(x.dtype), x], axis=1)
417420
x, _ = self.time_conv(x_combined, cache_x=None)
418421
new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(prev_cache.dtype)
419422

0 commit comments

Comments
 (0)