Skip to content

Commit e160f84

Browse files
committed
check_nan_jit modified
1 parent 5eef84f commit e160f84

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@
3333
def check_nan_jit(tensor: jax.Array, name: str, step: jax.Array):
3434
if tensor is None:
3535
return
36-
3736
has_nans = jnp.isnan(tensor).any()
3837
has_infs = jnp.isinf(tensor).any()
39-
jax.debug.print(f"[DEBUG SCHEDULER {jax.process_index()}] Step: {{step}} - {name}: "
40-
"Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
41-
step=step, shape=tensor.shape, has_nans_val=has_nans, has_infs_val=has_infs)
38+
if step is None:
39+
step = -1
4240

41+
# Print the actual dtype of the tensor's data
42+
jax.debug.print(f"[DEBUG SCHEDULER {jax.process_index()}] Step: {{step}} - {name}: "
43+
"Shape: {shape}, tensor.dtype: {dtype}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
44+
step=step, shape=tensor.shape, dtype=tensor.dtype, has_nans_val=has_nans, has_infs_val=has_infs)
4345

4446
@flax.struct.dataclass
4547
class UniPCMultistepSchedulerState:

0 commit comments

Comments
 (0)