@@ -1128,17 +1128,36 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11281128 iter_ = 1 + (t - 1 ) // 4
11291129 enc_feat_map = feat_cache ._enc_feat_map
11301130
1131- for i in range (iter_ ):
1132- enc_conv_idx = 0
1133- if i == 0 :
1134- out , enc_feat_map , enc_conv_idx = self .encoder (x [:, :1 , :, :, :], feat_cache = enc_feat_map , feat_idx = enc_conv_idx )
1135- else :
1136- out_ , enc_feat_map , enc_conv_idx = self .encoder (
1137- x [:, 1 + 4 * (i - 1 ) : 1 + 4 * i , :, :, :],
1138- feat_cache = enc_feat_map ,
1139- feat_idx = enc_conv_idx ,
1140- )
1141- out = jnp .concatenate ([out , out_ ], axis = 1 )
1131+ # Process first chunk explicitly
1132+ out_first , enc_feat_map , _ = self .encoder (x [:, :1 , :, :, :], feat_cache = enc_feat_map , feat_idx = 0 )
1133+
1134+ # Prepare remaining chunks for scan
1135+ def scan_body_encode (carry , x_chunk ):
1136+ feat_map = carry
1137+ out_chunk , updated_feat_map , _ = self .encoder (x_chunk , feat_cache = feat_map , feat_idx = 0 )
1138+ return updated_feat_map , out_chunk
1139+
1140+ if iter_ > 1 :
1141+ # We have remaining chunks to process. Let's reshape/stack them.
1142+ # x is (B, T, H, W, C) where T = 1 + 4 * (iter_ - 1)
1143+ # We want to scan over the iter_-1 blocks of size 4.
1144+ x_rest = x [:, 1 :, :, :, :]
1145+ b , t_rest , h , w , c = x_rest .shape
1146+ x_rest_blocks = x_rest .reshape (b , iter_ - 1 , 4 , h , w , c )
1147+ # scan over the blocks dimension (axis=1) -> swap axis 0 and 1
1148+ x_scan_input = jnp .swapaxes (x_rest_blocks , 0 , 1 ) # shape: (iter_ - 1, B, 4, H, W, C)
1149+
1150+ enc_feat_map , out_rest_stacked = jax .lax .scan (scan_body_encode , enc_feat_map , x_scan_input )
1151+ # out_rest_stacked shape: (iter_ - 1, B, T_out_chunk, H_out, W_out, C_out)
1152+
1153+ # Transpose back and flatten the iteration and time dimensions
1154+ out_rest_stacked = jnp .swapaxes (out_rest_stacked , 0 , 1 )
1155+ b_out , iters_out , t_out_chunk , h_out , w_out , c_out = out_rest_stacked .shape
1156+ out_rest = out_rest_stacked .reshape (b_out , iters_out * t_out_chunk , h_out , w_out , c_out )
1157+
1158+ out = jnp .concatenate ([out_first , out_rest ], axis = 1 )
1159+ else :
1160+ out = out_first
11421161
11431162 enc = self .quant_conv (out )
11441163 mu , logvar = enc [:, :, :, :, : self .z_dim ], enc [:, :, :, :, self .z_dim :]
@@ -1166,29 +1185,45 @@ def _decode(
11661185
11671186 dec_feat_map = feat_cache ._feat_map
11681187
1169- for i in range (iter_ ):
1170- conv_idx = 0
1171- if i == 0 :
1172- out , dec_feat_map , conv_idx = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = dec_feat_map , feat_idx = conv_idx )
1173- else :
1174- out_ , dec_feat_map , conv_idx = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = dec_feat_map , feat_idx = conv_idx )
1175-
1176- # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1177- # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1178- # Most likely due to an incorrect reshaping in the decoder.
1179- fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
1180- # When batch_size is 0, expand batch dim for concatenation
1181- # else, expand frame dim for concatenation so that batch dim stays intact.
1182- axis = 0
1183- if fm1 .shape [0 ] > 1 :
1184- axis = 1
1185-
1186- if len (fm1 .shape ) == 4 :
1187- fm1 = jnp .expand_dims (fm1 , axis = axis )
1188- fm2 = jnp .expand_dims (fm2 , axis = axis )
1189- fm3 = jnp .expand_dims (fm3 , axis = axis )
1190- fm4 = jnp .expand_dims (fm4 , axis = axis )
1191- out = jnp .concatenate ([out , fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1188+ def process_out_frame (out_ ):
1189+ # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1190+ # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1191+ # Most likely due to an incorrect reshaping in the decoder.
1192+ fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
1193+ axis = 0
1194+ if fm1 .shape [0 ] > 1 :
1195+ axis = 1
1196+
1197+ if len (fm1 .shape ) == 4 :
1198+ fm1 = jnp .expand_dims (fm1 , axis = axis )
1199+ fm2 = jnp .expand_dims (fm2 , axis = axis )
1200+ fm3 = jnp .expand_dims (fm3 , axis = axis )
1201+ fm4 = jnp .expand_dims (fm4 , axis = axis )
1202+ return jnp .concatenate ([fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1203+
1204+ # Process first chunk explicitly
1205+ out_first , dec_feat_map , _ = self .decoder (x [:, :1 , :, :, :], feat_cache = dec_feat_map , feat_idx = 0 )
1206+
1207+ def scan_body_decode (carry , x_chunk ):
1208+ feat_map = carry
1209+ out_chunk , updated_feat_map , _ = self .decoder (x_chunk , feat_cache = feat_map , feat_idx = 0 )
1210+ out_processed = process_out_frame (out_chunk )
1211+ return updated_feat_map , out_processed
1212+
1213+ if iter_ > 1 :
1214+ x_rest = x [:, 1 :, :, :, :]
1215+ # Scan over the time dimension directly
1216+ x_scan_input = jnp .swapaxes (jnp .expand_dims (x_rest , axis = 2 ), 0 , 1 ) # shape: (iter_ - 1, B, 1, H, W, C)
1217+
1218+ dec_feat_map , out_rest_stacked = jax .lax .scan (scan_body_decode , dec_feat_map , x_scan_input )
1219+
1220+ out_rest_stacked = jnp .swapaxes (out_rest_stacked , 0 , 1 )
1221+ b_out , iters_out , t_out_frames , h_out , w_out , c_out = out_rest_stacked .shape
1222+ out_rest = out_rest_stacked .reshape (b_out , iters_out * t_out_frames , h_out , w_out , c_out )
1223+
1224+ out = jnp .concatenate ([out_first , out_rest ], axis = 1 )
1225+ else :
1226+ out = out_first
11921227
11931228 out = jnp .clip (out , min = - 1.0 , max = 1.0 )
11941229 return out
0 commit comments