@@ -1131,9 +1131,15 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11311131 out_0 , enc_feat_map , _ = self .encoder (x [:, :1 , :, :, :], feat_cache = enc_feat_map , feat_idx = 0 )
11321132 out = out_0
11331133
1134- # 2. Process remaining frames in chunks of 4 using jax.lax.scan
1134+ # 2. Evaluate the second chunk (4 frames) manually to stabilize WanCausalConv3d caches to T=2.
1135+ # WanCausalConv3d uses cache_x = x[:, -2:]. After 1 frame, cache is T=1. After 4 frames, it stabilizes to T=2.
11351136 if t > 1 :
1136- x_rest = x [:, 1 :, :, :, :]
1137+ out_1 , enc_feat_map , _ = self .encoder (x [:, 1 :5 , :, :, :], feat_cache = enc_feat_map , feat_idx = 0 )
1138+ out = jnp .concatenate ([out_0 , out_1 ], axis = 1 )
1139+
1140+ # 3. Process remaining frames in chunks of 4 using jax.lax.scan
1141+ if t > 5 :
1142+ x_rest = x [:, 5 :, :, :, :]
11371143 B , T_rest , H , W , C = x_rest .shape
11381144 num_chunks = T_rest // 4
11391145
@@ -1157,7 +1163,7 @@ def scan_fn(carry_cache, input_chunk):
11571163 B_out , _ , _ , H_out , W_out , C_out = scanned_out_chunks .shape
11581164 scanned_out_chunks = jnp .reshape (scanned_out_chunks , (B_out , num_chunks , H_out , W_out , C_out ))
11591165
1160- out = jnp .concatenate ([out_0 , scanned_out_chunks ], axis = 1 )
1166+ out = jnp .concatenate ([out , scanned_out_chunks ], axis = 1 )
11611167
11621168 # 4. Update back to the wrapper object if needed
11631169 feat_cache ._enc_feat_map = enc_feat_map
0 commit comments