Commit a7b4952

fix for dtypes
1 parent 40501f8 commit a7b4952

1 file changed: 8 additions & 5 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -71,6 +71,7 @@ def __init__(
         padding, 3, "padding"
     )  # (D, H, W) padding amounts
     self.mesh = mesh
+    self.dtype = dtype

     self._causal_padding = (
         (0, 0),  # Batch dimension - no padding
@@ -154,7 +155,7 @@ def __call__(

     if cache_x is not None:
       x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
-      new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
+      new_cache = x_concat[:, -CACHE_T:, ...].astype(self.dtype)

       padding_needed = self._depth_padding_before - cache_x.shape[1]
       if padding_needed < 0:
@@ -164,7 +165,7 @@ def __call__(
         x_input = x_concat
         current_padding[1] = (padding_needed, 0)
     else:
-      new_cache = x[:, -CACHE_T:, ...]
+      new_cache = x[:, -CACHE_T:, ...].astype(self.dtype)
       x_input = x

     padding_to_apply = tuple(current_padding)
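
The two hunks above pin the rolling frame cache to the module's configured dtype: the first branch previously inherited the incoming cache's dtype, and the `else` branch stored the cache in whatever dtype `x` happened to carry. A minimal standalone sketch of the idea, assuming a bfloat16 module dtype; the `roll_cache` helper, shapes, and dtypes here are illustrative, not code from this repository:

import jax.numpy as jnp

CACHE_T = 2  # illustrative; mirrors the constant used in the diff

def roll_cache(x, cache_x, module_dtype):
  # Concatenate in the activation dtype, as the model does, but cast
  # the slice kept as the new cache to the module dtype so the cache
  # cannot silently drift to the activation's precision.
  x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
  return x_concat[:, -CACHE_T:, ...].astype(module_dtype)

x = jnp.ones((1, 4, 8, 8, 16), dtype=jnp.float32)      # new frames
cache = jnp.ones((1, 2, 8, 8, 16), dtype=jnp.float32)  # seeded cache
print(roll_cache(x, cache, jnp.bfloat16).dtype)        # bfloat16
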
@@ -443,6 +444,7 @@ def __init__(
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.nonlinearity = get_activation(non_linearity)
     self.norm1 = WanRMS_norm(
         dim=in_dim, rngs=rngs, images=False, channel_first=False
@@ -520,8 +522,8 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     x, c2 = self.conv2(x, cache.get("conv2"))
     new_cache["conv2"] = c2

-    out = (x + h).astype(input_dtype)
-    return out, new_cache
+    x = (x + h).astype(self.dtype)
+    return x, new_cache


 class WanAttentionBlock(nnx.Module):
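
The residual-output change above turns on JAX type promotion: if the shortcut `h` and the main path `x` disagree in dtype, `x + h` is computed in the promoted (wider) type, so the block now casts the sum back to the module's configured dtype rather than to the input's. A small sketch of the promotion behavior this guards against; the values and shapes are made up:

import jax.numpy as jnp

h = jnp.ones((2, 4), dtype=jnp.float32)    # shortcut path
x = jnp.ones((2, 4), dtype=jnp.bfloat16)   # main path
print((x + h).dtype)                       # float32: promotion widens the sum
print((x + h).astype(jnp.bfloat16).dtype)  # bfloat16: pinned to module dtype
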
@@ -535,6 +537,7 @@ def __init__(
       precision: jax.lax.Precision = None,
   ):
     self.dim = dim
+    self.dtype = dtype
     self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
     self.to_qkv = nnx.Conv(
         in_features=dim,
@@ -578,7 +581,7 @@ def __call__(self, x: jax.Array):
     x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
     x = self.proj(x)
     x = x.reshape(batch_size, time, height, width, channels)
-    out = (x + identity).astype(input_dtype)
+    out = (x + identity).astype(self.dtype)
     return out


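Taken together, the commit follows one pattern: each block stores the `dtype` it was constructed with and casts its residual output (or cache) to it. A hypothetical mini-module sketching the same pattern with Flax NNX; the class name, shapes, and values are invented for illustration and are not part of this file:

import jax
import jax.numpy as jnp
from flax import nnx

class ResidualCast(nnx.Module):
  # Invented example module: stores the constructor dtype and pins
  # the residual output to it, as the blocks in this commit now do.
  def __init__(self, dtype: jnp.dtype = jnp.float32):
    self.dtype = dtype

  def __call__(self, x: jax.Array, identity: jax.Array) -> jax.Array:
    return (x + identity).astype(self.dtype)

block = ResidualCast(dtype=jnp.bfloat16)
out = block(jnp.ones((2, 3), jnp.bfloat16), jnp.ones((2, 3), jnp.float32))
print(out.dtype)  # bfloat16, independent of the inputs' dtypes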