Skip to content

Commit cc73d06

Browse files
committed
check_nan fn corrected
1 parent ed9e8e6 commit cc73d06

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,20 @@
4141
BlockSizes = common_types.BlockSizes
4242

4343
def check_nan(tensor: jax.Array, name: str):
44+
if tensor is None:
45+
# jax.debug.print works fine with regular python strings and values
46+
print(f"[DEBUG NaN Check] {name} on process {jax.process_index()}: Tensor is None")
47+
return
48+
4449
has_nans = jnp.isnan(tensor).any()
4550
has_infs = jnp.isinf(tensor).any()
46-
# Use jax.debug.print to print during JITted execution
51+
52+
# Pass the JAX arrays (has_nans, has_infs) as kwargs
53+
# Use placeholders {} in the f-string for these runtime values
4754
jax.debug.print(f"[DEBUG NaN Check] {name} on process {jax.process_index()}: "
48-
f"Has NaNs: {has_nans}, Has Infs: {has_infs}")
55+
"Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
56+
has_nans_val=has_nans, has_infs_val=has_infs)
57+
4958

5059
def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, use_real: bool):
5160
h_dim = w_dim = 2 * (attention_head_dim // 6)

0 commit comments

Comments
 (0)