|
15 | 15 | """ |
16 | 16 |
|
17 | 17 | from typing import Tuple, List, Sequence, Union, Optional |
18 | | - |
| 18 | +import time |
19 | 19 | import flax |
20 | 20 | import jax |
21 | 21 | import jax.numpy as jnp |
@@ -1096,7 +1096,10 @@ def encode( |
1096 | 1096 | self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True |
1097 | 1097 | ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: |
1098 | 1098 | """Encode video into latent distribution.""" |
| 1099 | + s0 = time.perf_counter() |
1099 | 1100 | h = self._encode(x, feat_cache) |
| 1101 | + h.block_until_ready() |
| 1102 | + print(f"VAE Encode time: {time.perf_counter() - s0:.4f}s") |
1100 | 1103 | posterior = FlaxDiagonalGaussianDistribution(h) |
1101 | 1104 | if not return_dict: |
1102 | 1105 | return (posterior,) |
@@ -1145,7 +1148,10 @@ def decode( |
1145 | 1148 | # reshape channel last for JAX |
1146 | 1149 | z = jnp.transpose(z, (0, 2, 3, 4, 1)) |
1147 | 1150 | assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}" |
| 1151 | + s0 = time.perf_counter() |
1148 | 1152 | decoded = self._decode(z, feat_cache).sample |
| 1153 | + decoded.block_until_ready() |
| 1154 | + print(f"VAE Decode time: {time.perf_counter() - s0:.4f}s") |
1149 | 1155 | if not return_dict: |
1150 | 1156 | return (decoded,) |
1151 | 1157 | return FlaxDecoderOutput(sample=decoded) |
0 commit comments