Skip to content

Commit c0aeca7

Browse files
committed
fix for dtypes
1 parent 5213ea5 commit c0aeca7

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,10 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
511511
input_dtype = x.dtype
512512

513513
h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
514-
new_cache["shortcut"] = sc_cache.astype(self.dtype)
514+
if sc_cache is not None:
515+
new_cache["shortcut"] = sc_cache.astype(self.dtype)
516+
else:
517+
new_cache["shortcut"] = None
515518

516519
x = self.norm1(x)
517520
x = self.nonlinearity(x)

0 commit comments

Comments
 (0)