Skip to content

Commit 50dade6

Browse files
committed
modified encode and decode
1 parent 8fbb1bc commit 50dade6

1 file changed

Lines changed: 23 additions & 19 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,20 +1115,28 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11151115
iter_ = 1 + (t - 1) // 4
11161116
enc_feat_map = feat_cache._enc_feat_map
11171117

1118+
encoder_chunks = []
11181119
for i in range(iter_):
11191120
enc_conv_idx = 0
11201121
if i == 0:
11211122
out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
1123+
encoder_chunks.append(out)
11221124
else:
11231125
out_, enc_feat_map, enc_conv_idx = self.encoder(
11241126
x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
11251127
feat_cache=enc_feat_map,
11261128
feat_idx=enc_conv_idx,
11271129
)
11281130
start_concat = time.time()
1129-
out = jnp.concatenate([out, out_], axis=1)
1130-
out.block_until_ready()
1131+
encoder_chunks.append(out_)
1132+
out_.block_until_ready()
11311133
print(f"Encode step {i} concat time: {time.time() - start_concat}")
1134+
# out = jnp.concatenate([out, out_], axis=1) # Removed quadratic concat
1135+
# out.block_until_ready()
1136+
# print(f"Encode step {i} concat time: {time.time() - start_concat}")
1137+
1138+
# Final concatenation
1139+
out = jnp.concatenate(encoder_chunks, axis=1)
11321140

11331141
# Write the updated feature map back to the cache wrapper object; the result itself is built from local variables.
11341142
feat_cache._enc_feat_map = enc_feat_map
@@ -1158,33 +1166,29 @@ def _decode(
11581166

11591167
dec_feat_map = feat_cache._feat_map
11601168

1169+
decoder_chunks = []
11611170
for i in range(iter_):
11621171
conv_idx = 0
11631172
if i == 0:
11641173
out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1174+
decoder_chunks.append(out)
11651175
else:
11661176
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
11671177

1168-
# This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1169-
# Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1170-
# Most likely due to an incorrect reshaping in the decoder.
1178+
# Reorder frames [0, 1, 2, 3] -> [0, 2, 1, 3] using advanced indexing
1179+
# This replaces the splitting, expanding, and concatenating logic.
1180+
# Original: fm1(0), fm2(1), fm3(2), fm4(3) -> concat(out, fm1, fm3, fm2, fm4)
1181+
# Sequence: prev, 0, 2, 1, 3
1182+
# We append the reordered chunk to the list.
11711183
start_expand_concat = time.time()
1172-
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1173-
# When batch_size is not greater than 1, expand batch dim for concatenation
1174-
# else, expand frame dim for concatenation so that batch dim stays intact.
1175-
axis = 0
1176-
if fm1.shape[0] > 1:
1177-
axis = 1
1178-
1179-
if len(fm1.shape) == 4:
1180-
fm1 = jnp.expand_dims(fm1, axis=axis)
1181-
fm2 = jnp.expand_dims(fm2, axis=axis)
1182-
fm3 = jnp.expand_dims(fm3, axis=axis)
1183-
fm4 = jnp.expand_dims(fm4, axis=axis)
1184-
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1185-
out.block_until_ready()
1184+
out_chunk = out_[:, [0, 2, 1, 3], :, :, :]
1185+
decoder_chunks.append(out_chunk)
1186+
out_chunk.block_until_ready()
11861187
print(f"Decode step {i} expand+concat time: {time.time() - start_expand_concat}")
11871188

1189+
feat_cache._feat_map = dec_feat_map
1190+
out = jnp.concatenate(decoder_chunks, axis=1)
1191+
11881192
feat_cache._feat_map = dec_feat_map
11891193

11901194
out = jnp.clip(out, min=-1.0, max=1.0)

0 commit comments

Comments
 (0)