Skip to content

Commit ed0655d

Browse files
committed
check_nan_attn corrected
1 parent c1041dd commit ed0655d

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,21 @@
5151
def _maybe_aqt_einsum(quant: Quant):
5252
return jnp.einsum if quant is None else quant.einsum()
5353

54-
def check_nan_attn(tensor: jax.Array, name: str, tag: str = ""):
  """Print whether `tensor` contains NaN or Inf values (debug helper).

  Usable both eagerly and inside jitted code: the NaN/Inf flags are handed to
  `jax.debug.print` as runtime values instead of being branched on in Python.

  Args:
    tensor: Array to inspect. `None` is accepted and only reported.
    name: Label identifying the tensor in the output.
    tag: Optional extra label (e.g. the call site) printed before `name`.

  Returns:
    None. The only effect is the printed debug line.
  """
  process = jax.process_index()

  if tensor is None:
    # None is an ordinary Python value, never a tracer, so a plain print is
    # sufficient here.
    print(f"[DEBUG ATTN PY {process}] {tag} {name}: Tensor is None")
    return

  # Boolean scalars; these are tracers when this function is being jitted, so
  # they must not be used in a Python `if`.
  has_nans = jnp.isnan(tensor).any()
  has_infs = jnp.isinf(tensor).any()

  # Everything known at trace time (process index, labels, shape) is rendered
  # into the message up front. Braces are then escaped so that a label such as
  # "q{0}" cannot be mistaken for a placeholder when jax.debug.print applies
  # str.format to the template.
  static_prefix = f"[DEBUG ATTN JIT {process}] {tag} {name}: Shape: {tensor.shape}, "
  static_prefix = static_prefix.replace("{", "{{").replace("}", "}}")

  # Only the two runtime flags are passed as format arguments.
  jax.debug.print(
      static_prefix + "Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
      has_nans_val=has_nans,
      has_infs_val=has_infs,
  )
68+
6469

6570

6671
def _check_attention_inputs(query: Array, key: Array, value: Array) -> None:

0 commit comments

Comments
 (0)