Skip to content

Commit 15f9c27

Browse files
committed
fix
1 parent 2ada7ec commit 15f9c27

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,9 +1131,15 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11311131
out_0, enc_feat_map, _ = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=0)
11321132
out = out_0
11331133

1134-
# 2. Process remaining frames in chunks of 4 using jax.lax.scan
1134+
# 2. Evaluate the second chunk (4 frames) manually to stabilize WanCausalConv3d caches to T=2.
1135+
# WanCausalConv3d uses cache_x = x[:, -2:]. After 1 frame, cache is T=1. After 4 frames, it stabilizes to T=2.
11351136
if t > 1:
1136-
x_rest = x[:, 1:, :, :, :]
1137+
out_1, enc_feat_map, _ = self.encoder(x[:, 1:5, :, :, :], feat_cache=enc_feat_map, feat_idx=0)
1138+
out = jnp.concatenate([out_0, out_1], axis=1)
1139+
1140+
# 3. Process remaining frames in chunks of 4 using jax.lax.scan
1141+
if t > 5:
1142+
x_rest = x[:, 5:, :, :, :]
11371143
B, T_rest, H, W, C = x_rest.shape
11381144
num_chunks = T_rest // 4
11391145

@@ -1157,7 +1163,7 @@ def scan_fn(carry_cache, input_chunk):
11571163
B_out, _, _, H_out, W_out, C_out = scanned_out_chunks.shape
11581164
scanned_out_chunks = jnp.reshape(scanned_out_chunks, (B_out, num_chunks, H_out, W_out, C_out))
11591165

1160-
out = jnp.concatenate([out_0, scanned_out_chunks], axis=1)
1166+
out = jnp.concatenate([out, scanned_out_chunks], axis=1)
11611167

11621168
# 3. Update back to the wrapper object if needed
11631169
feat_cache._enc_feat_map = enc_feat_map

0 commit comments

Comments
 (0)