Skip to content

Commit 8fbb1bc

Browse files
committed
debug time
1 parent 3a1ccfd commit 8fbb1bc

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import flax
2020
import jax
2121
import jax.numpy as jnp
22+
import time
2223
from jax import tree_util
2324
from flax import nnx
2425
from ...configuration_utils import ConfigMixin
@@ -1124,7 +1125,10 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11241125
feat_cache=enc_feat_map,
11251126
feat_idx=enc_conv_idx,
11261127
)
1128+
start_concat = time.time()
11271129
out = jnp.concatenate([out, out_], axis=1)
1130+
out.block_until_ready()
1131+
print(f"Encode step {i} concat time: {time.time() - start_concat}")
11281132

11291133
# Update back to the wrapper object if needed, but for result we use local vars
11301134
feat_cache._enc_feat_map = enc_feat_map
@@ -1164,6 +1168,7 @@ def _decode(
11641168
# This is to bypass an issue where frame[1] should be frame[2] and vise versa.
11651169
# Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
11661170
# Most likely due to an incorrect reshaping in the decoder.
1171+
start_expand_concat = time.time()
11671172
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
11681173
# When batch_size is 0, expand batch dim for concatenation
11691174
# else, expand frame dim for concatenation so that batch dim stays intact.
@@ -1177,6 +1182,8 @@ def _decode(
11771182
fm3 = jnp.expand_dims(fm3, axis=axis)
11781183
fm4 = jnp.expand_dims(fm4, axis=axis)
11791184
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1185+
out.block_until_ready()
1186+
print(f"Decode step {i} expand+concat time: {time.time() - start_expand_concat}")
11801187

11811188
feat_cache._feat_map = dec_feat_map
11821189

0 commit comments

Comments
 (0)