Skip to content

Commit fee7d3a

Browse files
committed
Trying fix for vid distortion
1 parent 0d4588a commit fee7d3a

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,16 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
369369
if feat_cache[idx] is None:
370370
feat_cache = _update_cache(feat_cache, idx, jnp.copy(x))
371371
feat_idx += 1
372+
x = self.time_conv(x)
372373
else:
373374
cache_x = jnp.copy(x[:, -1:, :, :, :])
374375
x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
376+
# Discard the first frame of output as it corresponds to the cached frame (already processed)
377+
x = x[:, 1:, ...]
375378
feat_cache = _update_cache(feat_cache, idx, cache_x)
376379
feat_idx += 1
380+
else:
381+
x = self.time_conv(x)
377382

378383
return x, feat_cache, feat_idx
379384

0 commit comments

Comments
 (0)