Skip to content

Commit 91aa404

Browse files
committed
debug for nan pixels
1 parent 1b6b816 commit 91aa404

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import List, Union, Optional, Tuple
2020
from ...pyconfig import HyperParameters
2121
from functools import partial
22+
import numpy as np
2223
from flax import nnx
2324
from flax.linen import partitioning as nn_partitioning
2425
import jax
@@ -234,15 +235,26 @@ def __call__(
234235
scheduler_state=scheduler_state,
235236
rng=inference_rng,
236237
)
238+
max_logging.log(f"[DEBUG CALL] latents shape after loop: {latents.shape}")
239+
max_logging.log(f"[DEBUG CALL] NaNs in latents AFTER loop: {jnp.isnan(latents).any()}, Infs: {jnp.isinf(latents).any()}")
237240
if self.config.expand_timesteps:
238241
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
242+
max_logging.log(f"[DEBUG CALL] NaNs in latents AFTER frame mask: {jnp.isnan(latents).any()}")
239243
latents_bcthw = jnp.transpose(latents, (0, 4, 1, 2, 3))
244+
max_logging.log(f"[DEBUG CALL] NaNs in latents BEFORE denorm: {jnp.isnan(latents_bcthw).any()}")
240245
latents_denorm_bcthw = self._denormalize_latents(latents_bcthw)
246+
max_logging.log(f"[DEBUG CALL] NaNs in latents AFTER denorm: {jnp.isnan(latents_denorm_bcthw).any()}")
241247

242248

243249
if output_type == "latent":
244250
return jnp.transpose(latents_denorm_bcthw, (0, 2, 3, 4, 1))
245-
return self._decode_latents_to_video(latents_denorm_bcthw)
251+
max_logging.log(f"[DEBUG CALL] NaNs in latents BEFORE decode: {jnp.isnan(latents_denorm_bcthw).any()}")
252+
decoded_video = self._decode_latents_to_video(latents_denorm_bcthw)
253+
if isinstance(decoded_video, np.ndarray):
254+
max_logging.log(f"[DEBUG CALL] NaNs in video AFTER decode: {np.isnan(decoded_video).any()}")
255+
else:
256+
max_logging.log(f"[DEBUG CALL] Decoded video type: {type(decoded_video)}")
257+
return decoded_video
246258

247259
def run_inference_2_1_i2v(
248260
graphdef, sharded_state, rest_of_state,

0 commit comments

Comments
 (0)