Skip to content

Commit 699af62

Browse files
committed
Added Debug to check vae encode and decode time
1 parent ad56886 commit 699af62

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from typing import Tuple, List, Sequence, Union, Optional
18-
18+
import time
1919
import flax
2020
import jax
2121
import jax.numpy as jnp
@@ -1096,7 +1096,10 @@ def encode(
10961096
self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
10971097
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
10981098
"""Encode video into latent distribution."""
1099+
s0 = time.perf_counter()
10991100
h = self._encode(x, feat_cache)
1101+
h.block_until_ready()
1102+
print(f"VAE Encode time: {time.perf_counter() - s0:.4f}s")
11001103
posterior = FlaxDiagonalGaussianDistribution(h)
11011104
if not return_dict:
11021105
return (posterior,)
@@ -1145,7 +1148,10 @@ def decode(
11451148
# reshape channel last for JAX
11461149
z = jnp.transpose(z, (0, 2, 3, 4, 1))
11471150
assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}"
1151+
s0 = time.perf_counter()
11481152
decoded = self._decode(z, feat_cache).sample
1153+
decoded.block_until_ready()
1154+
print(f"VAE Decode time: {time.perf_counter() - s0:.4f}s")
11491155
if not return_dict:
11501156
return (decoded,)
11511157
return FlaxDecoderOutput(sample=decoded)

0 commit comments

Comments
 (0)