@@ -1123,27 +1123,44 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11231123 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
11241124 assert x .shape [- 1 ] == 3 , f"Expected input shape (N, D, H, W, 3), got { x .shape } "
11251125
1126- # Swap to (T, B, H, W, C) for scanning over time
1127- x_scan = jnp .swapaxes (x , 0 , 1 )
1126+ t = x .shape [1 ]
11281127 enc_feat_map = feat_cache ._enc_feat_map
11291128
1130- def scan_fn (carry_cache , input_frame ):
1131- # Expand time dimension to 1 for the encoder
1132- input_frame = jnp .expand_dims (input_frame , 1 )
1133- # idx is restarted at 0 for each chunk/frame conceptually
1134- enc_conv_idx = 0
1135- out_frame , new_cache , _ = self .encoder (input_frame , feat_cache = carry_cache , feat_idx = enc_conv_idx )
1136- out_frame = jnp .squeeze (out_frame , 1 )
1137- return new_cache , out_frame
1138-
1139- # Perform JAX scan
1140- final_enc_feat_map , encoded_frames = jax .lax .scan (scan_fn , enc_feat_map , x_scan )
1141-
1142- # Swap back to (B, T, ... )
1143- out = jnp .swapaxes (encoded_frames , 0 , 1 )
1129+ # 1. Evaluate the first frame manually to establish the initial cache with JAX Arrays.
1130+ # This prevents jax.lax.scan from crashing on type mismatch between None and ShapedArray.
1131+ out_0 , enc_feat_map , _ = self .encoder (x [:, :1 , :, :, :], feat_cache = enc_feat_map , feat_idx = 0 )
1132+ out = out_0
11441133
1145- # Update back to the wrapper object if needed
1146- feat_cache ._enc_feat_map = final_enc_feat_map
1134+ # 2. Process remaining frames in chunks of 4 using jax.lax.scan
1135+ if t > 1 :
1136+ x_rest = x [:, 1 :, :, :, :]
1137+ B , T_rest , H , W , C = x_rest .shape
1138+ num_chunks = T_rest // 4
1139+
1140+ # Reshape to (B, num_chunks, 4, H, W, C)
1141+ x_chunks = jnp .reshape (x_rest , (B , num_chunks , 4 , H , W , C ))
1142+
1143+ # Swap axes for scan traversal: (num_chunks, B, 4, H, W, C)
1144+ x_scan = jnp .swapaxes (x_chunks , 0 , 1 )
1145+
1146+ def scan_fn (carry_cache , input_chunk ):
1147+ # input_chunk shape: (B, 4, H, W, C)
1148+ out_chunk , new_cache , _ = self .encoder (input_chunk , feat_cache = carry_cache , feat_idx = 0 )
1149+ # out_chunk shape: (B, 1, H', W', C')
1150+ return new_cache , out_chunk
1151+
1152+ enc_feat_map , scanned_out_chunks = jax .lax .scan (scan_fn , enc_feat_map , x_scan )
1153+
1154+ # scanned_out_chunks shape: (num_chunks, B, 1, H', W', C')
1155+ scanned_out_chunks = jnp .swapaxes (scanned_out_chunks , 0 , 1 )
1156+
1157+ B_out , _ , _ , H_out , W_out , C_out = scanned_out_chunks .shape
1158+ scanned_out_chunks = jnp .reshape (scanned_out_chunks , (B_out , num_chunks , H_out , W_out , C_out ))
1159+
1160+ out = jnp .concatenate ([out_0 , scanned_out_chunks ], axis = 1 )
1161+
1162+ # 3. Update back to the wrapper object if needed
1163+ feat_cache ._enc_feat_map = enc_feat_map
11471164
11481165 enc = self .quant_conv (out )
11491166 mu , logvar = enc [:, :, :, :, : self .z_dim ], enc [:, :, :, :, self .z_dim :]
@@ -1170,22 +1187,35 @@ def _decode(
11701187
11711188 dec_feat_map = feat_cache ._feat_map
11721189
1173- # Evaluate the first frame manually to establish the initial cache.
1174- # The decoder returns 1 frame on the first step, and 4 frames on subsequent steps due to temporal upsampling .
1190+ # 1. Evaluate the first frame manually (Cache: None -> RepSentinel/ShapedArray)
1191+ # The decoder returns 1 frame on the first step.
11751192 out_0 , dec_feat_map , _ = self .decoder (x [:, 0 :1 , :, :, :], feat_cache = dec_feat_map , feat_idx = 0 )
11761193 out = out_0
11771194
1178- # Process remaining frames using jax.lax.scan (requires homogenous output shapes)
1195+ # 2. Evaluate the second frame manually (Cache: RepSentinel -> ShapedArray)
1196+ # This ensures that ALL cache components are ShapredArrays before entering jax.lax.scan,
1197+ # preventing TraceContext errors due to type mismatches.
11791198 if iter_ > 1 :
1180- x_rest = x [:, 1 :, :, :, :]
1181- x_scan = jnp .swapaxes (x_rest , 0 , 1 ) # (T-1, B, H, W, C)
1199+ out_1 , dec_feat_map , _ = self .decoder (x [:, 1 :2 , :, :, :], feat_cache = dec_feat_map , feat_idx = 0 )
1200+
1201+ # Bypass an issue where frame[1] should be frame[2] and vice versa.
1202+ fm1 = out_1 [:, 0 :1 , ...]
1203+ fm2 = out_1 [:, 1 :2 , ...]
1204+ fm3 = out_1 [:, 2 :3 , ...]
1205+ fm4 = out_1 [:, 3 :4 , ...]
1206+ out_1_fixed = jnp .concatenate ([fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1207+ out = jnp .concatenate ([out_0 , out_1_fixed ], axis = 1 )
1208+
1209+ # 3. Process remaining frames using jax.lax.scan (requires homogenous output and carry shapes)
1210+ if iter_ > 2 :
1211+ x_rest = x [:, 2 :, :, :, :]
1212+ x_scan = jnp .swapaxes (x_rest , 0 , 1 ) # (T-2, B, H, W, C)
11821213
11831214 def scan_fn (carry_cache , input_frame ):
11841215 input_frame = jnp .expand_dims (input_frame , 1 ) # (B, 1, H, W, C)
11851216 out_frames , new_cache , _ = self .decoder (input_frame , feat_cache = carry_cache , feat_idx = 0 )
11861217
11871218 # Bypass an issue where frame[1] should be frame[2] and vice versa.
1188- # Ensure dimensionality allows straightforward slicing:
11891219 fm1 = out_frames [:, 0 :1 , ...]
11901220 fm2 = out_frames [:, 1 :2 , ...]
11911221 fm3 = out_frames [:, 2 :3 , ...]
@@ -1196,17 +1226,17 @@ def scan_fn(carry_cache, input_frame):
11961226
11971227 dec_feat_map , scanned_out_frames = jax .lax .scan (scan_fn , dec_feat_map , x_scan )
11981228
1199- # scanned_out_frames is (T-1 , B, 4, H, W, C)
1229+ # scanned_out_frames is (T-2 , B, 4, H, W, C)
12001230 B = scanned_out_frames .shape [1 ]
1201- T_minus_1 = scanned_out_frames .shape [0 ]
1231+ T_minus_2 = scanned_out_frames .shape [0 ]
12021232 H , W , C = scanned_out_frames .shape [3 ], scanned_out_frames .shape [4 ], scanned_out_frames .shape [5 ]
12031233
1204- # Swap back to (B, T-1 , 4, H, W, C)
1234+ # Swap back to (B, T-2 , 4, H, W, C)
12051235 scanned_out_frames = jnp .swapaxes (scanned_out_frames , 0 , 1 )
1206- # Flatten the temporal axes to (B, (T-1 )*4, H, W, C)
1207- scanned_out_frames = jnp .reshape (scanned_out_frames , (B , T_minus_1 * 4 , H , W , C ))
1236+ # Flatten the temporal axes to (B, (T-2 )*4, H, W, C)
1237+ scanned_out_frames = jnp .reshape (scanned_out_frames , (B , T_minus_2 * 4 , H , W , C ))
12081238
1209- out = jnp .concatenate ([out_0 , scanned_out_frames ], axis = 1 )
1239+ out = jnp .concatenate ([out , scanned_out_frames ], axis = 1 )
12101240
12111241 feat_cache ._feat_map = dec_feat_map
12121242
0 commit comments