Skip to content

Commit 902b6db

Browse files
committed
fix for dtypes
1 parent c0aeca7 commit 902b6db

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -661,13 +661,13 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
661661
new_cache = {"resnets": []}
662662

663663
x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
664-
new_cache["resnets"].append(c.astype(self.dtype))
664+
new_cache["resnets"].append(c)
665665

666666
for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
667667
if attn is not None:
668668
x = attn(x)
669669
x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
670-
new_cache["resnets"].append(c.astype(self.dtype))
670+
new_cache["resnets"].append(c)
671671

672672
return x, new_cache
673673

@@ -741,11 +741,11 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
741741

742742
for i, resnet in enumerate(self.resnets):
743743
x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i])
744-
new_cache["resnets"].append(c.astype(self.dtype))
744+
new_cache["resnets"].append(c)
745745

746746
if self.upsamplers:
747747
x, c = self.upsamplers[0](x, cache.get("upsamplers", [None])[0])
748-
new_cache["upsamplers"].append(c.astype(self.dtype))
748+
new_cache["upsamplers"].append(c)
749749
return x, new_cache
750750

751751

@@ -899,7 +899,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
899899
for i, layer in enumerate(self.down_blocks):
900900
if isinstance(layer, (WanResidualBlock, WanResample)):
901901
x, c = layer(x, current_down_caches[i])
902-
new_cache["down_blocks"].append(c.astype(self.dtype))
902+
new_cache["down_blocks"].append(c)
903903
else:
904904
x = layer(x)
905905
new_cache["down_blocks"].append(None)
@@ -1038,7 +1038,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
10381038
current_up_caches = cache.get("up_blocks", [None] * len(self.up_blocks))
10391039
for i, up_block in enumerate(self.up_blocks):
10401040
x, c = up_block(x, current_up_caches[i])
1041-
new_cache["up_blocks"].append(c.astype(self.dtype))
1041+
new_cache["up_blocks"].append(c)
10421042

10431043
x = self.norm_out(x)
10441044
x = self.nonlinearity(x)

0 commit comments

Comments
 (0)