@@ -1115,20 +1115,28 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11151115 iter_ = 1 + (t - 1 ) // 4
11161116 enc_feat_map = feat_cache ._enc_feat_map
11171117
1118+ encoder_chunks = []
11181119 for i in range (iter_ ):
11191120 enc_conv_idx = 0
11201121 if i == 0 :
11211122 out , enc_feat_map , enc_conv_idx = self .encoder (x [:, :1 , :, :, :], feat_cache = enc_feat_map , feat_idx = enc_conv_idx )
1123+ encoder_chunks .append (out )
11221124 else :
11231125 out_ , enc_feat_map , enc_conv_idx = self .encoder (
11241126 x [:, 1 + 4 * (i - 1 ) : 1 + 4 * i , :, :, :],
11251127 feat_cache = enc_feat_map ,
11261128 feat_idx = enc_conv_idx ,
11271129 )
11281130 start_concat = time .time ()
1129- out = jnp . concatenate ([ out , out_ ], axis = 1 )
1130- out .block_until_ready ()
1131+ encoder_chunks . append ( out_ )
1132+ out_ .block_until_ready ()
11311133 print (f"Encode step { i } concat time: { time .time () - start_concat } " )
1134+ # out = jnp.concatenate([out, out_], axis=1) # Removed quadratic concat
1135+ # out.block_until_ready()
1136+ # print(f"Encode step {i} concat time: {time.time() - start_concat}")
1137+
1138+ # Final concatenation
1139+ out = jnp .concatenate (encoder_chunks , axis = 1 )
11321140
11331141 # Update back to the wrapper object if needed, but for result we use local vars
11341142 feat_cache ._enc_feat_map = enc_feat_map
@@ -1158,33 +1166,29 @@ def _decode(
11581166
11591167 dec_feat_map = feat_cache ._feat_map
11601168
1169+ decoder_chunks = []
11611170 for i in range (iter_ ):
11621171 conv_idx = 0
11631172 if i == 0 :
11641173 out , dec_feat_map , conv_idx = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = dec_feat_map , feat_idx = conv_idx )
1174+ decoder_chunks .append (out )
11651175 else :
11661176 out_ , dec_feat_map , conv_idx = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = dec_feat_map , feat_idx = conv_idx )
11671177
1168- # This is to bypass an issue where frame[1] should be frame[2] and vise versa.
1169- # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1170- # Most likely due to an incorrect reshaping in the decoder.
1178+ # Reorder frames [0, 1, 2, 3] -> [0, 2, 1, 3] using advanced indexing
1179+ # This replaces the splitting, expanding, and concatenating logic.
1180+ # Original: fm1(0), fm2(1), fm3(2), fm4(3) -> concat(out, fm1, fm3, fm2, fm4)
1181+ # Sequence: prev, 0, 2, 1, 3
1182+ # We append the reordered chunk to the list.
11711183 start_expand_concat = time .time ()
1172- fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
1173- # When batch_size is 0, expand batch dim for concatenation
1174- # else, expand frame dim for concatenation so that batch dim stays intact.
1175- axis = 0
1176- if fm1 .shape [0 ] > 1 :
1177- axis = 1
1178-
1179- if len (fm1 .shape ) == 4 :
1180- fm1 = jnp .expand_dims (fm1 , axis = axis )
1181- fm2 = jnp .expand_dims (fm2 , axis = axis )
1182- fm3 = jnp .expand_dims (fm3 , axis = axis )
1183- fm4 = jnp .expand_dims (fm4 , axis = axis )
1184- out = jnp .concatenate ([out , fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1185- out .block_until_ready ()
1184+ out_chunk = out_ [:, [0 , 2 , 1 , 3 ], :, :, :]
1185+ decoder_chunks .append (out_chunk )
1186+ out_chunk .block_until_ready ()
11861187 print (f"Decode step { i } expand+concat time: { time .time () - start_expand_concat } " )
11871188
1189+ feat_cache ._feat_map = dec_feat_map
1190+ out = jnp .concatenate (decoder_chunks , axis = 1 )
1191+
11881192 feat_cache ._feat_map = dec_feat_map
11891193
11901194 out = jnp .clip (out , min = - 1.0 , max = 1.0 )
0 commit comments