We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ae22683 commit 3f752cbCopy full SHA for 3f752cb
1 file changed
src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -1206,7 +1206,7 @@ def _decode(
1206
fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...]
1207
axis = 1 if fm1.shape[0] > 1 else 0
1208
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
1209
- out_1 = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
+ out_1 = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)
1210
1211
out_list = [out_0, out_1]
1212
@@ -1226,7 +1226,7 @@ def scan_fn(carry, chunk_in):
1226
fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...]
1227
1228
1229
- new_chunk = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
+ new_chunk = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)
1230
1231
return next_feat_map, new_chunk
1232
0 commit comments