We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 63bf3e9 commit a996cceCopy full SHA for a996cce
1 file changed
src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -154,7 +154,7 @@ def __call__(
154
155
if cache_x is not None:
156
x_concat = jnp.concatenate([cache_x, x], axis=1)
157
- new_cache = x_concat[:, -CACHE_T:, ...]
+ new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
158
159
padding_needed = self._depth_padding_before - cache_x.shape[1]
160
if padding_needed < 0:
@@ -415,7 +415,7 @@ def __call__(
415
prev_cache = cache.get("time_conv")
416
x_combined = jnp.concatenate([prev_cache, x], axis=1)
417
x, _ = self.time_conv(x_combined, cache_x=None)
418
- new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...]
+ new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(prev_cache.dtype)
419
420
else:
421
if hasattr(self, "resample"):
0 commit comments