Skip to content

Commit 63bf3e9

Browse files
committed
fix in wanresample
1 parent 88b86c4 commit 63bf3e9

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,10 @@ def __call__(
412412
h_new, w_new, c_new = x.shape[1:]
413413
x = x.reshape(b, t, h_new, w_new, c_new)
414414

415-
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
416-
new_cache["time_conv"] = tc_cache
415+
prev_cache = cache.get("time_conv")
416+
x_combined = jnp.concatenate([prev_cache, x], axis=1)
417+
x, _ = self.time_conv(x_combined, cache_x=None)
418+
new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...]
417419

418420
else:
419421
if hasattr(self, "resample"):

0 commit comments

Comments
 (0)