@@ -156,7 +156,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
         shard_width_axis = "context"

         x_padded = jax.lax.with_sharding_constraint(
-            x_padded, jax.sharding.PartitionSpec(None, None, shard_axis, shard_width_axis, None)
+            x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
         )

         out = self.conv(x_padded)
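For context on the one-line change above: `jax.lax.with_sharding_constraint` pins an intermediate value to a named layout inside `jit`, and the new spec additionally shards the batch dim over a `"data"` mesh axis. A minimal sketch of the pattern, assuming a hypothetical 2x2 mesh whose axis names mirror this diff's `"data"` and `"context"` (the real `shard_axis` value lives elsewhere in the file and is replaced by `None` here):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh for illustration; needs 4 devices (e.g. run with
# XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU).
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("data", "context"))

@jax.jit
def constrain(x):
    # Batch dim sharded over "data", width dim over "context"; None = replicated.
    spec = PartitionSpec("data", None, None, "context", None)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

x = constrain(jnp.zeros((2, 4, 8, 8, 16)))  # (N, T, H, W, C); shapes are illustrative
```

Without the constraint, XLA's SPMD partitioner is free to pick its own layout for the intermediate, which can introduce resharding around the convolution.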
@@ -1125,24 +1125,27 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
         x = jnp.transpose(x, (0, 2, 3, 4, 1))
         assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"

-        t = x.shape[1]
-        iter_ = 1 + (t - 1) // 4
+        # Swap to (T, B, H, W, C) for scanning over time
+        x_scan = jnp.swapaxes(x, 0, 1)
         enc_feat_map = feat_cache._enc_feat_map

-        for i in range(iter_):
+        def scan_fn(carry_cache, input_frame):
+            # Expand the time dimension back to 1 for the encoder
+            input_frame = jnp.expand_dims(input_frame, 1)
+            # The conv-cache index restarts at 0 for every frame
             enc_conv_idx = 0
-            if i == 0:
-                out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
-            else:
-                out_, enc_feat_map, enc_conv_idx = self.encoder(
-                    x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
-                    feat_cache=enc_feat_map,
-                    feat_idx=enc_conv_idx,
-                )
-                out = jnp.concatenate([out, out_], axis=1)
+            out_frame, new_cache, _ = self.encoder(input_frame, feat_cache=carry_cache, feat_idx=enc_conv_idx)
+            out_frame = jnp.squeeze(out_frame, 1)
+            return new_cache, out_frame

-        # Update back to the wrapper object if needed, but for the result we use local vars
-        feat_cache._enc_feat_map = enc_feat_map
+        # Scan over time, threading the feature cache as the carry
+        final_enc_feat_map, encoded_frames = jax.lax.scan(scan_fn, enc_feat_map, x_scan)
+
+        # Swap back to (B, T, ...)
+        out = jnp.swapaxes(encoded_frames, 0, 1)
+
+        # Write the final cache back to the wrapper object
+        feat_cache._enc_feat_map = final_enc_feat_map

         enc = self.quant_conv(out)
         mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
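The rewrite above is the standard carry-the-cache `lax.scan` pattern: the feature cache is the carry, each step consumes one time-major frame, and the stacked per-step outputs come back with a leading time axis. A self-contained sketch of just that pattern, with `fake_encoder` as a hypothetical stand-in for `self.encoder` (the real module's signature differs):

```python
import jax
import jax.numpy as jnp

def fake_encoder(cache, frame):
    # Hypothetical stand-in for self.encoder: mixes the frame into a cached state.
    new_cache = 0.9 * cache + 0.1 * frame
    return frame + new_cache, new_cache

x = jnp.ones((2, 5, 8))            # (B, T, F); shapes are illustrative
x_scan = jnp.swapaxes(x, 0, 1)     # (T, B, F): scan iterates over the leading axis

def scan_fn(carry_cache, frame):
    out, new_cache = fake_encoder(carry_cache, frame)
    return new_cache, out          # carry the cache forward, stack the outputs

init_cache = jnp.zeros((2, 8))
final_cache, ys = jax.lax.scan(scan_fn, init_cache, x_scan)
out = jnp.swapaxes(ys, 0, 1)       # back to (B, T, F)
```

Unlike the old Python loop, `scan` traces the body exactly once, so compile time no longer grows with the number of frames.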
@@ -1169,29 +1172,43 @@ def _decode(

         dec_feat_map = feat_cache._feat_map

-        for i in range(iter_):
-            conv_idx = 0
-            if i == 0:
-                out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
-            else:
-                out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
-
-                # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
-                # Ideally this shouldn't be needed; however, we can't find where the frames go out of sync.
-                # Most likely due to an incorrect reshaping in the decoder.
-                fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
-                # When there is no batch dim, expand the batch dim for concatenation;
-                # else, expand the frame dim so that the batch dim stays intact.
-                axis = 0
-                if fm1.shape[0] > 1:
-                    axis = 1
-
-                if len(fm1.shape) == 4:
-                    fm1 = jnp.expand_dims(fm1, axis=axis)
-                    fm2 = jnp.expand_dims(fm2, axis=axis)
-                    fm3 = jnp.expand_dims(fm3, axis=axis)
-                    fm4 = jnp.expand_dims(fm4, axis=axis)
-                out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
+        # Evaluate the first frame manually to establish the initial cache.
+        # The decoder returns 1 frame on the first step and 4 frames on subsequent
+        # steps due to temporal upsampling, so only the rest can go through scan.
+        out_0, dec_feat_map, _ = self.decoder(x[:, 0:1, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
+        out = out_0
+
+        # Process the remaining frames with jax.lax.scan (requires homogeneous output shapes)
+        if iter_ > 1:
+            x_rest = x[:, 1:, :, :, :]
+            x_scan = jnp.swapaxes(x_rest, 0, 1)  # (T-1, B, H, W, C)
+
+            def scan_fn(carry_cache, input_frame):
+                input_frame = jnp.expand_dims(input_frame, 1)  # (B, 1, H, W, C)
+                out_frames, new_cache, _ = self.decoder(input_frame, feat_cache=carry_cache, feat_idx=0)
+
+                # Bypass an issue where frame[1] should be frame[2] and vice versa.
+                # Slice with ranges so each frame keeps its time dim for concatenation:
+                fm1 = out_frames[:, 0:1, ...]
+                fm2 = out_frames[:, 1:2, ...]
+                fm3 = out_frames[:, 2:3, ...]
+                fm4 = out_frames[:, 3:4, ...]
+
+                fixed_out_frames = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
+                return new_cache, fixed_out_frames
+
+            dec_feat_map, scanned_out_frames = jax.lax.scan(scan_fn, dec_feat_map, x_scan)
+
+            # scanned_out_frames is (T-1, B, 4, H, W, C)
+            T_minus_1, B = scanned_out_frames.shape[0], scanned_out_frames.shape[1]
+            H, W, C = scanned_out_frames.shape[3:6]
+
+            # Swap back to (B, T-1, 4, H, W, C)
+            scanned_out_frames = jnp.swapaxes(scanned_out_frames, 0, 1)
+            # Flatten the temporal axes to (B, (T-1)*4, H, W, C)
+            scanned_out_frames = jnp.reshape(scanned_out_frames, (B, T_minus_1 * 4, H, W, C))
+
+            out = jnp.concatenate([out_0, scanned_out_frames], axis=1)

         feat_cache._feat_map = dec_feat_map

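The subtle part of the decode rewrite is the shape bookkeeping after the scan: each step emits (B, 4, H, W, C), scan stacks them to (T-1, B, 4, H, W, C), and the swapaxes-plus-reshape must interleave the two temporal axes in the right order. A standalone sanity check with toy shapes, where each value encodes step * 10 + frame so an ordering mistake would show up in the printout:

```python
import jax.numpy as jnp

# Toy stand-in for the scanned decoder output; values encode step * 10 + frame.
T_minus_1, B, H, W, C = 3, 2, 1, 1, 1
steps = jnp.arange(T_minus_1)[:, None, None, None, None, None]
frames = jnp.arange(4)[None, None, :, None, None, None]
scanned = jnp.broadcast_to(steps * 10 + frames, (T_minus_1, B, 4, H, W, C))

# Same reshuffle as in the diff: (T-1, B, 4, ...) -> (B, (T-1)*4, ...)
out = jnp.reshape(jnp.swapaxes(scanned, 0, 1), (B, T_minus_1 * 4, H, W, C))
print(out[0, :, 0, 0, 0])  # [ 0  1  2  3 10 11 12 13 20 21 22 23]
```

Because reshape is row-major, flattening the (T-1, 4) pair yields index t * 4 + f, which is exactly the frame order the old Python loop produced.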