@@ -1126,44 +1126,51 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11261126 t = x .shape [1 ]
11271127 enc_feat_map = feat_cache ._enc_feat_map
11281128
1129- # 1. Evaluate the first frame manually to establish the initial cache with JAX Arrays.
1130- # This prevents jax.lax.scan from crashing on type mismatch between None and ShapedArray.
1131- out_0 , enc_feat_map , _ = self .encoder (x [:, :1 , :, :, :], feat_cache = enc_feat_map , feat_idx = 0 )
1132- out = out_0
1133-
1134- # 2. Evaluate the second chunk (4 frames) manually to stabilize WanCausalConv3d caches to T=2.
1135- # WanCausalConv3d uses cache_x = x[:, -2:]. After 1 frame, cache is T=1. After 4 frames, it stabilizes to T=2.
1136- if t > 1 :
1137- out_1 , enc_feat_map , _ = self .encoder (x [:, 1 :5 , :, :, :], feat_cache = enc_feat_map , feat_idx = 0 )
1138- out = jnp .concatenate ([out_0 , out_1 ], axis = 1 )
1139-
1140- # 3. Process remaining frames in chunks of 4 using jax.lax.scan
1141- if t > 5 :
1142- x_rest = x [:, 5 :, :, :, :]
1143- B , T_rest , H , W , C = x_rest .shape
1144- num_chunks = T_rest // 4
1145-
1146- # Reshape to (B, num_chunks, 4, H, W, C)
1147- x_chunks = jnp .reshape (x_rest , (B , num_chunks , 4 , H , W , C ))
1148-
1149- # Swap axes for scan traversal: (num_chunks, B, 4, H, W, C)
1150- x_scan = jnp .swapaxes (x_chunks , 0 , 1 )
1151-
1152- def scan_fn (carry_cache , input_chunk ):
1153- # input_chunk shape: (B, 4, H, W, C)
1154- out_chunk , new_cache , _ = self .encoder (input_chunk , feat_cache = carry_cache , feat_idx = 0 )
1155- # out_chunk shape: (B, 1, H', W', C')
1156- return new_cache , out_chunk
1129+ @nnx .jit
1130+ def encode_sequence (encoder , x_seq , current_enc_feat_map ):
1131+ t_seq = x_seq .shape [1 ]
1132+ # 1. Evaluate the first frame manually to establish the initial cache with JAX Arrays.
1133+ # This prevents jax.lax.scan from crashing on type mismatch between None and ShapedArray.
1134+ out_0 , current_enc_feat_map , _ = encoder (x_seq [:, :1 , :, :, :], feat_cache = current_enc_feat_map , feat_idx = 0 )
1135+ out_seq = out_0
1136+
1137+ # 2. Evaluate the second chunk (4 frames) manually to stabilize WanCausalConv3d caches to T=2.
1138+ # WanCausalConv3d uses cache_x = x[:, -2:]. After 1 frame, cache is T=1. After 4 frames, it stabilizes to T=2.
1139+ if t_seq > 1 :
1140+ out_1 , current_enc_feat_map , _ = encoder (x_seq [:, 1 :5 , :, :, :], feat_cache = current_enc_feat_map , feat_idx = 0 )
1141+ out_seq = jnp .concatenate ([out_0 , out_1 ], axis = 1 )
1142+
1143+ # 3. Process remaining frames in chunks of 4 using jax.lax.scan
1144+ if t_seq > 5 :
1145+ x_rest = x_seq [:, 5 :, :, :, :]
1146+ B , T_rest , H , W , C = x_rest .shape
1147+ num_chunks = T_rest // 4
1148+
1149+ # Reshape to (B, num_chunks, 4, H, W, C)
1150+ x_chunks = jnp .reshape (x_rest , (B , num_chunks , 4 , H , W , C ))
11571151
1158- enc_feat_map , scanned_out_chunks = jax .lax .scan (scan_fn , enc_feat_map , x_scan )
1159-
1160- # scanned_out_chunks shape: (num_chunks, B, 1, H', W', C')
1161- scanned_out_chunks = jnp .swapaxes (scanned_out_chunks , 0 , 1 )
1162-
1163- B_out , _ , _ , H_out , W_out , C_out = scanned_out_chunks .shape
1164- scanned_out_chunks = jnp .reshape (scanned_out_chunks , (B_out , num_chunks , H_out , W_out , C_out ))
1152+ # Swap axes for scan traversal: (num_chunks, B, 4, H, W, C)
1153+ x_scan = jnp .swapaxes (x_chunks , 0 , 1 )
1154+
1155+ def scan_fn (carry_cache , input_chunk ):
1156+ # input_chunk shape: (B, 4, H, W, C)
1157+ out_chunk , new_cache , _ = encoder (input_chunk , feat_cache = carry_cache , feat_idx = 0 )
1158+ # out_chunk shape: (B, 1, H', W', C')
1159+ return new_cache , out_chunk
1160+
1161+ current_enc_feat_map , scanned_out_chunks = jax .lax .scan (scan_fn , current_enc_feat_map , x_scan )
1162+
1163+ # scanned_out_chunks shape: (num_chunks, B, 1, H', W', C')
1164+ scanned_out_chunks = jnp .swapaxes (scanned_out_chunks , 0 , 1 )
1165+
1166+ B_out , _ , _ , H_out , W_out , C_out = scanned_out_chunks .shape
1167+ scanned_out_chunks = jnp .reshape (scanned_out_chunks , (B_out , num_chunks , H_out , W_out , C_out ))
1168+
1169+ out_seq = jnp .concatenate ([out_seq , scanned_out_chunks ], axis = 1 )
11651170
1166- out = jnp .concatenate ([out , scanned_out_chunks ], axis = 1 )
1171+ return out_seq , current_enc_feat_map
1172+
1173+ out , enc_feat_map = encode_sequence (self .encoder , x , enc_feat_map )
11671174
11681175 # 3. Update back to the wrapper object if needed
11691176 feat_cache ._enc_feat_map = enc_feat_map
@@ -1193,57 +1200,64 @@ def _decode(
11931200
11941201 dec_feat_map = feat_cache ._feat_map
11951202
1196- # 1. Evaluate the first frame manually (Cache: None -> RepSentinel/ShapedArray)
1197- # The decoder returns 1 frame on the first step.
1198- out_0 , dec_feat_map , _ = self .decoder (x [:, 0 :1 , :, :, :], feat_cache = dec_feat_map , feat_idx = 0 )
1199- out = out_0
1200-
1201- # 2. Evaluate the second frame manually (Cache: RepSentinel -> ShapedArray)
1202- # This ensures that ALL cache components are ShapredArrays before entering jax.lax.scan,
1203- # preventing TraceContext errors due to type mismatches.
1204- if iter_ > 1 :
1205- out_1 , dec_feat_map , _ = self .decoder (x [:, 1 :2 , :, :, :], feat_cache = dec_feat_map , feat_idx = 0 )
1203+ @nnx .jit
1204+ def decode_sequence (decoder , x_seq , current_dec_feat_map ):
1205+ iter_s = x_seq .shape [1 ]
12061206
1207- # Bypass an issue where frame[1] should be frame[2] and vice versa.
1208- fm1 = out_1 [:, 0 :1 , ...]
1209- fm2 = out_1 [:, 1 :2 , ...]
1210- fm3 = out_1 [:, 2 :3 , ...]
1211- fm4 = out_1 [:, 3 :4 , ...]
1212- out_1_fixed = jnp .concatenate ([fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1213- out = jnp .concatenate ([out_0 , out_1_fixed ], axis = 1 )
1214-
1215- # 3. Process remaining frames using jax.lax.scan (requires homogenous output and carry shapes)
1216- if iter_ > 2 :
1217- x_rest = x [:, 2 :, :, :, :]
1218- x_scan = jnp .swapaxes (x_rest , 0 , 1 ) # (T-2, B, H, W, C)
1219-
1220- def scan_fn (carry_cache , input_frame ):
1221- input_frame = jnp .expand_dims (input_frame , 1 ) # (B, 1, H, W, C)
1222- out_frames , new_cache , _ = self .decoder (input_frame , feat_cache = carry_cache , feat_idx = 0 )
1223-
1224- # Bypass an issue where frame[1] should be frame[2] and vice versa.
1225- fm1 = out_frames [:, 0 :1 , ...]
1226- fm2 = out_frames [:, 1 :2 , ...]
1227- fm3 = out_frames [:, 2 :3 , ...]
1228- fm4 = out_frames [:, 3 :4 , ...]
1207+ # 1. Evaluate the first frame manually (Cache: None -> RepSentinel/ShapedArray)
1208+ # The decoder returns 1 frame on the first step.
1209+ out_0 , current_dec_feat_map , _ = decoder (x_seq [:, 0 :1 , :, :, :], feat_cache = current_dec_feat_map , feat_idx = 0 )
1210+ out_seq = out_0
1211+
1212+ # 2. Evaluate the second frame manually (Cache: RepSentinel -> ShapedArray)
1213+ # This ensures that ALL cache components are ShapedArrays before entering jax.lax.scan,
1214+ # preventing TraceContext errors due to type mismatches.
1215+ if iter_s > 1 :
1216+ out_1 , current_dec_feat_map , _ = decoder (x_seq [:, 1 :2 , :, :, :], feat_cache = current_dec_feat_map , feat_idx = 0 )
12291217
1230- fixed_out_frames = jnp .concatenate ([fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1231- return new_cache , fixed_out_frames
1232-
1233- dec_feat_map , scanned_out_frames = jax .lax .scan (scan_fn , dec_feat_map , x_scan )
1234-
1235- # scanned_out_frames is (T-2, B, 4, H, W, C)
1236- B = scanned_out_frames .shape [1 ]
1237- T_minus_2 = scanned_out_frames .shape [0 ]
1238- H , W , C = scanned_out_frames .shape [3 ], scanned_out_frames .shape [4 ], scanned_out_frames .shape [5 ]
1239-
1240- # Swap back to (B, T-2, 4, H, W, C)
1241- scanned_out_frames = jnp .swapaxes (scanned_out_frames , 0 , 1 )
1242- # Flatten the temporal axes to (B, (T-2)*4, H, W, C)
1243- scanned_out_frames = jnp .reshape (scanned_out_frames , (B , T_minus_2 * 4 , H , W , C ))
1244-
1245- out = jnp .concatenate ([out , scanned_out_frames ], axis = 1 )
1246-
1218+ # Bypass an issue where frame[1] should be frame[2] and vice versa.
1219+ fm1 = out_1 [:, 0 :1 , ...]
1220+ fm2 = out_1 [:, 1 :2 , ...]
1221+ fm3 = out_1 [:, 2 :3 , ...]
1222+ fm4 = out_1 [:, 3 :4 , ...]
1223+ out_1_fixed = jnp .concatenate ([fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1224+ out_seq = jnp .concatenate ([out_0 , out_1_fixed ], axis = 1 )
1225+
1226+ # 3. Process remaining frames using jax.lax.scan (requires homogeneous output and carry shapes)
1227+ if iter_s > 2 :
1228+ x_rest = x_seq [:, 2 :, :, :, :]
1229+ x_scan = jnp .swapaxes (x_rest , 0 , 1 ) # (T-2, B, H, W, C)
1230+
1231+ def scan_fn (carry_cache , input_frame ):
1232+ input_frame = jnp .expand_dims (input_frame , 1 ) # (B, 1, H, W, C)
1233+ out_frames , new_cache , _ = decoder (input_frame , feat_cache = carry_cache , feat_idx = 0 )
1234+
1235+ # Bypass an issue where frame[1] should be frame[2] and vice versa.
1236+ fm1 = out_frames [:, 0 :1 , ...]
1237+ fm2 = out_frames [:, 1 :2 , ...]
1238+ fm3 = out_frames [:, 2 :3 , ...]
1239+ fm4 = out_frames [:, 3 :4 , ...]
1240+
1241+ fixed_out_frames = jnp .concatenate ([fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1242+ return new_cache , fixed_out_frames
1243+
1244+ current_dec_feat_map , scanned_out_frames = jax .lax .scan (scan_fn , current_dec_feat_map , x_scan )
1245+
1246+ # scanned_out_frames is (T-2, B, 4, H, W, C)
1247+ B = scanned_out_frames .shape [1 ]
1248+ T_minus_2 = scanned_out_frames .shape [0 ]
1249+ H , W , C = scanned_out_frames .shape [3 ], scanned_out_frames .shape [4 ], scanned_out_frames .shape [5 ]
1250+
1251+ # Swap back to (B, T-2, 4, H, W, C)
1252+ scanned_out_frames = jnp .swapaxes (scanned_out_frames , 0 , 1 )
1253+ # Flatten the temporal axes to (B, (T-2)*4, H, W, C)
1254+ scanned_out_frames = jnp .reshape (scanned_out_frames , (B , T_minus_2 * 4 , H , W , C ))
1255+
1256+ out_seq = jnp .concatenate ([out_seq , scanned_out_frames ], axis = 1 )
1257+
1258+ return out_seq , current_dec_feat_map
1259+
1260+ out , dec_feat_map = decode_sequence (self .decoder , x , dec_feat_map )
12471261 feat_cache ._feat_map = dec_feat_map
12481262
12491263 out = jnp .clip (out , min = - 1.0 , max = 1.0 )
0 commit comments