|
14 | 14 | limitations under the License. |
15 | 15 | """ |
16 | 16 |
|
| 17 | +import time |
17 | 18 | from typing import Tuple, List, Sequence, Union, Optional |
18 | 19 |
|
19 | 20 | import flax |
@@ -1161,7 +1162,10 @@ def encode( |
1161 | 1162 | self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True |
1162 | 1163 | ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: |
1163 | 1164 | """Encode video into latent distribution.""" |
| 1165 | + s0 = time.perf_counter() |
1164 | 1166 | h = self._encode(x, feat_cache) |
| 1167 | + h.block_until_ready() |
| 1168 | + print(f"VAE Encode time: {time.perf_counter() - s0:.4f}s") |
1165 | 1169 | posterior = WanDiagonalGaussianDistribution(h) |
1166 | 1170 | if not return_dict: |
1167 | 1171 | return (posterior,) |
@@ -1216,7 +1220,10 @@ def decode( |
1216 | 1220 | # reshape channel last for JAX |
1217 | 1221 | z = jnp.transpose(z, (0, 2, 3, 4, 1)) |
1218 | 1222 | assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}" |
| 1223 | + s0 = time.perf_counter() |
1219 | 1224 | decoded = self._decode(z, feat_cache).sample |
| 1225 | + decoded.block_until_ready() |
| 1226 | + print(f"VAE Decode time: {time.perf_counter() - s0:.4f}s") |
1220 | 1227 | if not return_dict: |
1221 | 1228 | return (decoded,) |
1222 | 1229 | return FlaxDecoderOutput(sample=decoded) |
0 commit comments