@@ -1190,19 +1190,16 @@ def _decode(
11901190 # This is to bypass an issue where frame[1] should be frame[2] and vise versa.
11911191 # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
11921192 # Most likely due to an incorrect reshaping in the decoder.
1193- fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
1193+ # fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
11941194 # When batch_size is 0, expand batch dim for concatenation
11951195 # else, expand frame dim for concatenation so that batch dim stays intact.
11961196 axis = 0
1197- if fm1 .shape [0 ] > 1 :
1197+ if out_ .shape [0 ] > 1 :
11981198 axis = 1
11991199
1200- if len (fm1 .shape ) == 4 :
1201- fm1 = jnp .expand_dims (fm1 , axis = axis )
1202- fm2 = jnp .expand_dims (fm2 , axis = axis )
1203- fm3 = jnp .expand_dims (fm3 , axis = axis )
1204- fm4 = jnp .expand_dims (fm4 , axis = axis )
1205- out = jnp .concatenate ([out , fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1200+ if len (out_ .shape ) == 4 :
1201+ out_ = jnp .expand_dims (out_ , axis = axis )
1202+ out = jnp .concatenate ([out , out_ ], axis = 1 )
12061203
12071204 feat_cache ._feat_map = dec_feat_map
12081205
0 commit comments