@@ -360,7 +360,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
360360 feat_cache = _update_cache (feat_cache , idx , cache_x )
361361 feat_idx += 1
362362 x = x .reshape (b , t , h , w , 2 , c )
363- x = jnp .stack ([x [:, :, :, :, 0 , :], x [:, :, :, :, 1 , :]], axis = 1 )
363+ # x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
364+ x = x .transpose (0 , 1 , 4 , 2 , 3 , 5 )
364365 x = x .reshape (b , t * 2 , h , w , c )
365366 t = x .shape [1 ]
366367 x = x .reshape (b * t , h , w , c )
@@ -1160,23 +1161,7 @@ def _decode(
11601161 out , dec_feat_map , conv_idx = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = dec_feat_map , feat_idx = conv_idx )
11611162 else :
11621163 out_ , dec_feat_map , conv_idx = self .decoder (x [:, i : i + 1 , :, :, :], feat_cache = dec_feat_map , feat_idx = conv_idx )
1163-
1164- # This is to bypass an issue where frame[1] should be frame[2] and vise versa.
1165- # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1166- # Most likely due to an incorrect reshaping in the decoder.
1167- fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
1168- # When batch_size is 0, expand batch dim for concatenation
1169- # else, expand frame dim for concatenation so that batch dim stays intact.
1170- axis = 0
1171- if fm1 .shape [0 ] > 1 :
1172- axis = 1
1173-
1174- if len (fm1 .shape ) == 4 :
1175- fm1 = jnp .expand_dims (fm1 , axis = axis )
1176- fm2 = jnp .expand_dims (fm2 , axis = axis )
1177- fm3 = jnp .expand_dims (fm3 , axis = axis )
1178- fm4 = jnp .expand_dims (fm4 , axis = axis )
1179- out = jnp .concatenate ([out , fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1164+ out = jnp .concatenate ([out , out_ ], axis = 1 )
11801165
11811166 feat_cache ._feat_map = dec_feat_map
11821167
0 commit comments