Skip to content

Commit 693bf4c

Browse files
committed
removed frame reordering
1 parent c247d99 commit 693bf4c

1 file changed

Lines changed: 5 additions & 8 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,19 +1190,16 @@ def _decode(
11901190
# This is to bypass an issue where frame[1] should be frame[2] and vise versa.
11911191
# Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
11921192
# Most likely due to an incorrect reshaping in the decoder.
1193-
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1193+
# fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
11941194
# When batch_size is 0, expand batch dim for concatenation
11951195
# else, expand frame dim for concatenation so that batch dim stays intact.
11961196
axis = 0
1197-
if fm1.shape[0] > 1:
1197+
if out_.shape[0] > 1:
11981198
axis = 1
11991199

1200-
if len(fm1.shape) == 4:
1201-
fm1 = jnp.expand_dims(fm1, axis=axis)
1202-
fm2 = jnp.expand_dims(fm2, axis=axis)
1203-
fm3 = jnp.expand_dims(fm3, axis=axis)
1204-
fm4 = jnp.expand_dims(fm4, axis=axis)
1205-
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1200+
if len(out_.shape) == 4:
1201+
out_ = jnp.expand_dims(out_, axis=axis)
1202+
out = jnp.concatenate([out, out_], axis=1)
12061203

12071204
feat_cache._feat_map = dec_feat_map
12081205

0 commit comments

Comments
 (0)