1919import flax
2020import jax
2121import jax .numpy as jnp
22+ import time
2223from jax import tree_util
2324from flax import nnx
2425from ...configuration_utils import ConfigMixin
@@ -1124,7 +1125,10 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11241125 feat_cache = enc_feat_map ,
11251126 feat_idx = enc_conv_idx ,
11261127 )
1128+ start_concat = time .time ()
11271129 out = jnp .concatenate ([out , out_ ], axis = 1 )
1130+ out .block_until_ready ()
1131+ print (f"Encode step { i } concat time: { time .time () - start_concat } " )
11281132
11291133 # Update back to the wrapper object if needed, but for result we use local vars
11301134 feat_cache ._enc_feat_map = enc_feat_map
@@ -1164,6 +1168,7 @@ def _decode(
11641168 # This is to bypass an issue where frame[1] should be frame[2] and vice versa.
11651169 # Ideally we shouldn't need to do this; however, we can't find where the frame is going out of sync.
11661170 # Most likely due to an incorrect reshaping in the decoder.
1171+ start_expand_concat = time .time ()
11671172 fm1 , fm2 , fm3 , fm4 = out_ [:, 0 , :, :, :], out_ [:, 1 , :, :, :], out_ [:, 2 , :, :, :], out_ [:, 3 , :, :, :]
11681173 # When batch_size is 0, expand batch dim for concatenation
11691174 # else, expand frame dim for concatenation so that batch dim stays intact.
@@ -1177,6 +1182,8 @@ def _decode(
11771182 fm3 = jnp .expand_dims (fm3 , axis = axis )
11781183 fm4 = jnp .expand_dims (fm4 , axis = axis )
11791184 out = jnp .concatenate ([out , fm1 , fm3 , fm2 , fm4 ], axis = 1 )
1185+ out .block_until_ready ()
1186+ print (f"Decode step { i } expand+concat time: { time .time () - start_expand_concat } " )
11801187
11811188 feat_cache ._feat_map = dec_feat_map
11821189
0 commit comments