|
51 | 51 | def _maybe_aqt_einsum(quant: Quant): |
52 | 52 | return jnp.einsum if quant is None else quant.einsum() |
53 | 53 |
|
54 | | -def check_nan_attn(tensor: jax.Array, name: str, device_id: int): |
55 | | - if tensor is None: return |
| 54 | +def check_nan_attn(tensor: jax.Array, name: str, tag: str = ""): |
| 55 | + if tensor is None: |
| 56 | + # This print is fine, it's not in JIT on None |
| 57 | + print(f"[DEBUG ATTN PY {jax.process_index()}] {tag} {name}: Tensor is None") |
| 58 | + return |
| 59 | + |
| 60 | + # These are JAX boolean arrays (tracers when JITted) |
56 | 61 | has_nans = jnp.isnan(tensor).any() |
57 | 62 | has_infs = jnp.isinf(tensor).any() |
58 | | - jax.debug.print(f"[DEBUG ATTN {device_id}] {name}: " |
59 | | - "Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}", |
60 | | - has_nans_val=has_nans, has_infs_val=has_infs) |
61 | | - if has_nans or has_infs: |
62 | | - # Optional: Print more stats if non-finite |
63 | | - jax.debug.print(f" {name} shape: {tensor.shape}, dtype: {tensor.dtype}") |
| 63 | + |
| 64 | + # Pass the tracers as keyword arguments to jax.debug.print |
| 65 | + jax.debug.print(f"[DEBUG ATTN JIT {jax.process_index()}] {tag} {name}: " |
| 66 | + "Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}", |
| 67 | + shape=tensor.shape, has_nans_val=has_nans, has_infs_val=has_infs) |
| 68 | + |
64 | 69 |
|
65 | 70 |
|
66 | 71 | def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: |
|
0 commit comments