Skip to content

Commit 938a18b

Browse files
committed
fix for dtypes
1 parent a7b4952 commit 938a18b

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __call__(
155155

156156
if cache_x is not None:
157157
x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
158-
new_cache = x_concat[:, -CACHE_T:, ...].astype(self.dtype)
158+
new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
159159

160160
padding_needed = self._depth_padding_before - cache_x.shape[1]
161161
if padding_needed < 0:
@@ -165,7 +165,7 @@ def __call__(
165165
x_input = x_concat
166166
current_padding[1] = (padding_needed, 0)
167167
else:
168-
new_cache = x[:, -CACHE_T:, ...].astype(self.dtype)
168+
new_cache = x[:, -CACHE_T:, ...].astype(x.dtype)
169169
x_input = x
170170

171171
padding_to_apply = tuple(current_padding)
@@ -522,7 +522,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
522522
x, c2 = self.conv2(x, cache.get("conv2"))
523523
new_cache["conv2"] = c2
524524

525-
x = (x + h).astype(self.dtype)
525+
x = (x + h).astype(input_dtype)
526526
return x, new_cache
527527

528528

@@ -581,7 +581,7 @@ def __call__(self, x: jax.Array):
581581
x = jnp.squeeze(x, 1).reshape(batch_size * time, height, width, channels)
582582
x = self.proj(x)
583583
x = x.reshape(batch_size, time, height, width, channels)
584-
out = (x + identity).astype(self.dtype)
584+
out = (x + identity).astype(input_dtype)
585585
return out
586586

587587

0 commit comments

Comments
 (0)