Skip to content

Commit 3f752cb

Browse files
committed
Fixing VAE decoding issue
1 parent ae22683 commit 3f752cb

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def _decode(
12061206
fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...]
12071207
axis = 1 if fm1.shape[0] > 1 else 0
12081208
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
1209-
out_1 = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1209+
out_1 = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)
12101210

12111211
out_list = [out_0, out_1]
12121212

@@ -1226,7 +1226,7 @@ def scan_fn(carry, chunk_in):
12261226
fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...]
12271227
axis = 1 if fm1.shape[0] > 1 else 0
12281228
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
1229-
new_chunk = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1229+
new_chunk = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)
12301230

12311231
return next_feat_map, new_chunk
12321232

0 commit comments

Comments
 (0)