|
19 | 19 | from typing import List, Union, Optional, Tuple |
20 | 20 | from ...pyconfig import HyperParameters |
21 | 21 | from functools import partial |
| 22 | +import numpy as np |
22 | 23 | from flax import nnx |
23 | 24 | from flax.linen import partitioning as nn_partitioning |
24 | 25 | import jax |
@@ -234,15 +235,26 @@ def __call__( |
234 | 235 | scheduler_state=scheduler_state, |
235 | 236 | rng=inference_rng, |
236 | 237 | ) |
| 238 | + max_logging.log(f"[DEBUG CALL] latents shape after loop: {latents.shape}") |
| 239 | + max_logging.log(f"[DEBUG CALL] NaNs in latents AFTER loop: {jnp.isnan(latents).any()}, Infs: {jnp.isinf(latents).any()}") |
237 | 240 | if self.config.expand_timesteps: |
238 | 241 | latents = (1 - first_frame_mask) * condition + first_frame_mask * latents |
| 242 | + max_logging.log(f"[DEBUG CALL] NaNs in latents AFTER frame mask: {jnp.isnan(latents).any()}") |
239 | 243 | latents_bcthw = jnp.transpose(latents, (0, 4, 1, 2, 3)) |
| 244 | + max_logging.log(f"[DEBUG CALL] NaNs in latents BEFORE denorm: {jnp.isnan(latents_bcthw).any()}") |
240 | 245 | latents_denorm_bcthw = self._denormalize_latents(latents_bcthw) |
| 246 | + max_logging.log(f"[DEBUG CALL] NaNs in latents AFTER denorm: {jnp.isnan(latents_denorm_bcthw).any()}") |
241 | 247 |
|
242 | 248 |
|
243 | 249 | if output_type == "latent": |
244 | 250 | return jnp.transpose(latents_denorm_bcthw, (0, 2, 3, 4, 1)) |
245 | | - return self._decode_latents_to_video(latents_denorm_bcthw) |
| 251 | + max_logging.log(f"[DEBUG CALL] NaNs in latents BEFORE decode: {jnp.isnan(latents_denorm_bcthw).any()}") |
| 252 | + decoded_video = self._decode_latents_to_video(latents_denorm_bcthw) |
| 253 | + if isinstance(decoded_video, np.ndarray): |
| 254 | + max_logging.log(f"[DEBUG CALL] NaNs in video AFTER decode: {np.isnan(decoded_video).any()}") |
| 255 | + else: |
| 256 | + max_logging.log(f"[DEBUG CALL] Decoded video type: {type(decoded_video)}") |
| 257 | + return decoded_video |
246 | 258 |
|
247 | 259 | def run_inference_2_1_i2v( |
248 | 260 | graphdef, sharded_state, rest_of_state, |
|
0 commit comments