Skip to content

Commit c247d99

Browse files
committed
Added Debug to check vae encode and decode time
1 parent 6338698 commit c247d99

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17+
import time
1718
from typing import Tuple, List, Sequence, Union, Optional
1819

1920
import flax
@@ -1161,7 +1162,10 @@ def encode(
11611162
self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
11621163
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
11631164
"""Encode video into latent distribution."""
1165+
s0 = time.perf_counter()
11641166
h = self._encode(x, feat_cache)
1167+
h.block_until_ready()
1168+
print(f"VAE Encode time: {time.perf_counter() - s0:.4f}s")
11651169
posterior = WanDiagonalGaussianDistribution(h)
11661170
if not return_dict:
11671171
return (posterior,)
@@ -1216,7 +1220,10 @@ def decode(
12161220
# reshape channel last for JAX
12171221
z = jnp.transpose(z, (0, 2, 3, 4, 1))
12181222
assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}"
1223+
s0 = time.perf_counter()
12191224
decoded = self._decode(z, feat_cache).sample
1225+
decoded.block_until_ready()
1226+
print(f"VAE Decode time: {time.perf_counter() - s0:.4f}s")
12201227
if not return_dict:
12211228
return (decoded,)
12221229
return FlaxDecoderOutput(sample=decoded)

0 commit comments

Comments
 (0)