Skip to content

Commit bef52f5

Browse files
committed
Trying fix for vid distortion
1 parent fee7d3a commit bef52f5

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,9 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
369369
if feat_cache[idx] is None:
370370
feat_cache = _update_cache(feat_cache, idx, jnp.copy(x))
371371
feat_idx += 1
372-
x = self.time_conv(x)
372+
# Pad with 2 zeros to satisfy kernel size 3 and stride 2 for a single frame
373+
x_padded = jnp.pad(x, ((0, 0), (2, 0), (0, 0), (0, 0), (0, 0)))
374+
x = self.time_conv(x_padded)
373375
else:
374376
cache_x = jnp.copy(x[:, -1:, :, :, :])
375377
x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
@@ -382,7 +384,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
382384

383385
return x, feat_cache, feat_idx
384386

385-
386387
class WanResidualBlock(nnx.Module):
387388

388389
def __init__(

0 commit comments

Comments
 (0)