Skip to content

Commit 7c72f90

Browse files
committed
fix for dtypes
1 parent 938a18b commit 7c72f90

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def initialize_cache(self, batch_size, height, width, dtype):
136136
def __call__(
137137
self, x: jax.Array, cache_x: Optional[jax.Array] = None
138138
) -> Tuple[jax.Array, jax.Array]:
139+
x = x.astype(self.dtype)
139140
# OPTIMIZATION: Spatial Partitioning during execution
140141
if self.mesh is not None and "fsdp" in self.mesh.axis_names:
141142
height = x.shape[2]
@@ -155,7 +156,7 @@ def __call__(
155156

156157
if cache_x is not None:
157158
x_concat = jnp.concatenate([cache_x.astype(x.dtype), x], axis=1)
158-
new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
159+
new_cache = x_concat[:, -CACHE_T:, ...].astype(self.dtype)
159160

160161
padding_needed = self._depth_padding_before - cache_x.shape[1]
161162
if padding_needed < 0:
@@ -165,7 +166,7 @@ def __call__(
165166
x_input = x_concat
166167
current_padding[1] = (padding_needed, 0)
167168
else:
168-
new_cache = x[:, -CACHE_T:, ...].astype(x.dtype)
169+
new_cache = x[:, -CACHE_T:, ...].astype(self.dtype)
169170
x_input = x
170171

171172
padding_to_apply = tuple(current_padding)
@@ -376,6 +377,7 @@ def initialize_cache(self, batch_size, height, width, dtype):
376377
def __call__(
377378
self, x: jax.Array, cache: Dict[str, Any] = None
378379
) -> Tuple[jax.Array, Dict[str, Any]]:
380+
x = x.astype(self.dtype)
379381
if cache is None:
380382
cache = {}
381383
new_cache = {}
@@ -389,7 +391,7 @@ def __call__(
389391

390392
elif self.mode == "upsample3d":
391393
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
392-
new_cache["time_conv"] = tc_cache.astype(cache["time_conv"].dtype)
394+
new_cache["time_conv"] = tc_cache.astype(self.dtype)
393395

394396
b, t, h, w, c = x.shape
395397
x = x.reshape(b, t, h, w, 2, c // 2)
@@ -419,7 +421,7 @@ def __call__(
419421
prev_cache = cache.get("time_conv")
420422
x_combined = jnp.concatenate([prev_cache.astype(x.dtype), x], axis=1)
421423
x, _ = self.time_conv(x_combined, cache_x=None)
422-
new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(prev_cache.dtype)
424+
new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(self.dtype)
423425

424426
else:
425427
if hasattr(self, "resample"):
@@ -522,7 +524,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
522524
x, c2 = self.conv2(x, cache.get("conv2"))
523525
new_cache["conv2"] = c2
524526

525-
x = (x + h).astype(input_dtype)
527+
x = (x + h).astype(self.dtype)
526528
return x, new_cache
527529

528530

0 commit comments

Comments
 (0)