Skip to content

Commit ed9e8e6

Browse files
committed
check_nan fn corrected
1 parent 685173f commit ed9e8e6

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@
4141
BlockSizes = common_types.BlockSizes
4242

4343
def check_nan(tensor: jax.Array, name: str):
44-
if jnp.isnan(tensor).any():
45-
print(f"[DEBUG NaN Check] NaNs detected in {name} on process {jax.process_index()}")
44+
has_nans = jnp.isnan(tensor).any()
45+
has_infs = jnp.isinf(tensor).any()
46+
# Use jax.debug.print to print during JITted execution
47+
jax.debug.print(f"[DEBUG NaN Check] {name} on process {jax.process_index()}: "
48+
f"Has NaNs: {has_nans}, Has Infs: {has_infs}")
4649

4750
def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, use_real: bool):
4851
h_dim = w_dim = 2 * (attention_head_dim // 6)

0 commit comments

Comments
 (0)