
Commit e7d7ba1

debug for NaNs

1 parent 7572e09 commit e7d7ba1
3 files changed: 30 additions & 85 deletions

src/maxdiffusion/models/attention_flax.py
Lines changed: 0 additions & 48 deletions

```diff
@@ -51,21 +51,6 @@
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
 
-def check_nan_attn(tensor: jax.Array, name: str, tag: str = ""):
-  if tensor is None:
-    # This print is fine, it's not in JIT on None
-    print(f"[DEBUG ATTN PY {jax.process_index()}] {tag} {name}: Tensor is None")
-    return
-
-  # These are JAX boolean arrays (tracers when JITted)
-  has_nans = jnp.isnan(tensor).any()
-  has_infs = jnp.isinf(tensor).any()
-
-  # Pass the tracers as keyword arguments to jax.debug.print
-  jax.debug.print(f"[DEBUG ATTN JIT {jax.process_index()}] {tag} {name}: "
-                  "Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
-                  shape=tensor.shape, has_nans_val=has_nans, has_infs_val=has_infs)
-
 
 
 def _check_attention_inputs(query: Array, key: Array, value: Array) -> None:
```
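The removed helper leaned on a detail worth keeping in mind: inside a jitted function, `jnp.isnan(tensor).any()` is a tracer, so it cannot drive a Python `if`, but it can be handed to `jax.debug.print`, which fills it in at run time. A minimal standalone sketch of that pattern, outside this repo:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Traced boolean: unusable in a Python `if`, fine as a debug-print argument.
    has_nans = jnp.isnan(x).any()
    jax.debug.print("has NaNs: {h}", h=has_nans)
    return x * 2.0

f(jnp.array([1.0, jnp.nan]))  # prints "has NaNs: True" when the kernel runs
```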
```diff
@@ -961,13 +946,7 @@ def __call__(
       rotary_emb: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
-      tag: str = "attn"
   ) -> jax.Array:
-    check_nan_attn(hidden_states, "Input hidden_states", tag)
-    if encoder_hidden_states is not None:
-      check_nan_attn(encoder_hidden_states, "Input encoder_hidden_states", tag)
-    if rotary_emb is not None:
-      check_nan_attn(rotary_emb, "Input rotary_emb", tag)
 
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
@@ -982,79 +961,58 @@ def __call__(
     with self.conditional_named_scope("attn_qkv_proj"):
       with self.conditional_named_scope("proj_query"):
         query_proj = self.query(hidden_states)
-        check_nan_attn(query_proj, "query_proj", tag)
       with self.conditional_named_scope("proj_key"):
         key_proj = self.key(encoder_hidden_states)
-        check_nan_attn(key_proj, "key_proj", tag)
       with self.conditional_named_scope("proj_value"):
         value_proj = self.value(encoder_hidden_states)
-        check_nan_attn(value_proj, "value_proj", tag)
 
       if self.qk_norm:
         with self.conditional_named_scope("attn_q_norm"):
           query_proj = self.norm_q(query_proj)
-          check_nan_attn(query_proj, "query_proj normed", tag)
         with self.conditional_named_scope("attn_k_norm"):
           key_proj = self.norm_k(key_proj)
-          check_nan_attn(key_proj, "key_proj normed", tag)
 
       if rotary_emb is not None:  # Only for SELF-ATTENTION
         with self.conditional_named_scope("attn_rope"):
           # Unflatten is done HERE for RoPE
           query_proj = _unflatten_heads(query_proj, self.heads)
-          check_nan_attn(query_proj, "query_proj unflattened", tag)
           key_proj = _unflatten_heads(key_proj, self.heads)
-          check_nan_attn(key_proj, "key_proj unflattened", tag)
           value_proj = _unflatten_heads(value_proj, self.heads)
-          check_nan_attn(value_proj, "value_proj unflattened", tag)
           # output of _unflatten_heads Batch, heads, seq_len, head_dim
           query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-          check_nan_attn(query_proj, "query_proj after RoPE", tag)
-          check_nan_attn(key_proj, "key_proj after RoPE", tag)
           query_proj = checkpoint_name(query_proj, "query_proj")
           key_proj = checkpoint_name(key_proj, "key_proj")
           value_proj = checkpoint_name(value_proj, "value_proj")
         with self.conditional_named_scope("attn_compute"):
           attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-          check_nan_attn(attn_output, "attn_output from attention_op", tag)
 
       else:
         # NEW PATH for I2V CROSS-ATTENTION
         with self.conditional_named_scope("proj_query"):
           query_proj = self.query(hidden_states)
-          check_nan_attn(query_proj, "query_proj I2V", tag)
         if self.qk_norm:
           with self.conditional_named_scope("attn_q_norm"):
             query_proj = self.norm_q(query_proj)
-            check_nan_attn(query_proj, "query_proj normed I2V", tag)
 
         encoder_hidden_states_img = encoder_hidden_states[:, :self.image_seq_len, :]
         encoder_hidden_states_text = encoder_hidden_states[:, self.image_seq_len:, :]
-        check_nan_attn(encoder_hidden_states_img, "EHS_img", tag)
-        check_nan_attn(encoder_hidden_states_text, "EHS_text", tag)
 
         # Text K/V
         with self.conditional_named_scope("proj_key"):
           key_proj_text = self.key(encoder_hidden_states_text)
-          check_nan_attn(key_proj_text, "key_proj_text", tag)
         if self.qk_norm:
           with self.conditional_named_scope("attn_k_norm"):
             key_proj_text = self.norm_k(key_proj_text)
-            check_nan_attn(key_proj_text, "key_proj_text normed", tag)
         with self.conditional_named_scope("proj_value"):
           value_proj_text = self.value(encoder_hidden_states_text)
-          check_nan_attn(value_proj_text, "value_proj_text", tag)
 
         # Image K/V
         with self.conditional_named_scope("add_proj_k"):
           key_proj_img = self.add_k_proj(encoder_hidden_states_img)
-          check_nan_attn(key_proj_img, "key_proj_img", tag)
         with self.conditional_named_scope("norm_add_k"):
           key_proj_img = self.norm_added_k(key_proj_img)
-          check_nan_attn(key_proj_img, "key_proj_img normed", tag)
         with self.conditional_named_scope("add_proj_v"):
           value_proj_img = self.add_v_proj(encoder_hidden_states_img)
-          check_nan_attn(value_proj_img, "value_proj_img", tag)
 
         # Checkpointing
         query_proj = checkpoint_name(query_proj, "query_proj")
@@ -1066,25 +1024,19 @@ def __call__(
         # Attention - tensors are (B, S, D)
         with self.conditional_named_scope("cross_attn_text_apply"):
           attn_output_text = self.attention_op.apply_attention(query_proj, key_proj_text, value_proj_text)
-          check_nan_attn(attn_output_text, "attn_output_text_h", tag)
         with self.conditional_named_scope("norm_added_q"):
           query_proj_img = self.norm_added_q(query_proj)
-          check_nan_attn(query_proj_img, "query_proj_img normed", tag)
         with self.conditional_named_scope("cross_attn_img_apply"):
           attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img)
-          check_nan_attn(attn_output_img, "attn_output_img", tag)
 
         attn_output = attn_output_text + attn_output_img
-        check_nan_attn(attn_output, "attn_output final I2V", tag)
 
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
 
     with self.conditional_named_scope("attn_out_proj"):
       hidden_states = self.proj_attn(attn_output)
-      check_nan_attn(hidden_states, "hidden_states after proj_attn", tag)
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    check_nan_attn(hidden_states, "hidden_states after dropout", tag)
     return hidden_states
```
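If these checks are ever needed again, one cheap way to keep them in the tree without paying for them is a module-level switch, so no debug callback is staged into the jitted graph when it is off. A hypothetical sketch (`DEBUG_NANS` is not a flag in this repo):

```python
import jax
import jax.numpy as jnp

DEBUG_NANS = False  # hypothetical switch, flipped on only while hunting NaNs

def maybe_check(tensor: jax.Array, name: str) -> None:
    # Evaluated at trace time: with the flag off, nothing is recorded
    # in the compiled computation at all.
    if not DEBUG_NANS:
        return
    jax.debug.print(name + ": NaNs={n}, Infs={i}",
                    n=jnp.isnan(tensor).any(),
                    i=jnp.isinf(tensor).any())
```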

src/maxdiffusion/models/wan/transformers/transformer_wan.py
Lines changed: 1 addition & 35 deletions

```diff
@@ -40,21 +40,6 @@
 
 BlockSizes = common_types.BlockSizes
 
-def check_nan(tensor: jax.Array, name: str):
-  if tensor is None:
-    # jax.debug.print works fine with regular python strings and values
-    print(f"[DEBUG NaN Check] {name} on process {jax.process_index()}: Tensor is None")
-    return
-
-  has_nans = jnp.isnan(tensor).any()
-  has_infs = jnp.isinf(tensor).any()
-
-  # Pass the JAX arrays (has_nans, has_infs) as kwargs
-  # Use placeholders {} in the f-string for these runtime values
-  jax.debug.print(f"[DEBUG NaN Check] {name} on process {jax.process_index()}: "
-                  "Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
-                  has_nans_val=has_nans, has_infs_val=has_infs)
-
 
 def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, use_real: bool):
   h_dim = w_dim = 2 * (attention_head_dim // 6)
@@ -388,15 +373,8 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    check_nan(hidden_states, "TransformerBlock Input hidden_states")
-    check_nan(encoder_hidden_states, "TransformerBlock Input encoder_hidden_states")
-    check_nan(temb, "TransformerBlock Input temb")
-    if rotary_emb is not None:
-      check_nan(rotary_emb, "TransformerBlock Input rotary_emb")
     with self.conditional_named_scope("transformer_block"):
       with self.conditional_named_scope("adaln"):
-        scale_shift_all = (self.adaln_scale_shift_table.value + temb.astype(jnp.float32))
-        check_nan(scale_shift_all, "AdaLN scale_shift_all")
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
           (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
         )
@@ -408,56 +386,44 @@ def __call__(
       with self.conditional_named_scope("self_attn"):
         with self.conditional_named_scope("self_attn_norm"):
           norm_hidden_states = self.norm1(hidden_states.astype(jnp.float32))
-          check_nan(norm_hidden_states, "Self-Attn norm1 output")
           norm_hidden_states = (norm_hidden_states * (1 + scale_msa) + shift_msa).astype(
             hidden_states.dtype
           )
-          check_nan(norm_hidden_states, "Self-Attn norm_hidden_states after AdaLN")
         with self.conditional_named_scope("self_attn_attn"):
           attn_output = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_hidden_states,
             rotary_emb=rotary_emb,
             deterministic=deterministic,
             rngs=rngs,
-            tag="SELF",
           )
-          check_nan(attn_output, "Self-Attn attn_output (attn1)")
         with self.conditional_named_scope("self_attn_residual"):
           hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-          check_nan(hidden_states, "Self-Attn hidden_states after residual")
 
       # 2. Cross-attention
       residual = hidden_states
       norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-      check_nan(norm_hidden_states, "Cross-Attn norm_hidden_states (norm2)")
       attn_output = self.attn2(
-        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs, tag="CROSS"
+        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
       )
-      check_nan(attn_output, "Cross-Attn attn_output (attn2)")
       hidden_states = residual + attn_output
-      check_nan(hidden_states, "Cross-Attn hidden_states after residual")
 
       # 3. Feed-forward
       residual = hidden_states
       with self.conditional_named_scope("mlp"):
         with self.conditional_named_scope("mlp_norm"):
           norm_hidden_states = self.norm3(hidden_states.astype(jnp.float32))
-          check_nan(norm_hidden_states, "MLP norm3 output")
           norm_hidden_states = (norm_hidden_states * (1 + c_scale_msa) + c_shift_msa).astype(
             hidden_states.dtype
           )
-          check_nan(norm_hidden_states, "MLP norm_hidden_states after AdaLN")
 
         with self.conditional_named_scope("mlp_ffn"):
           ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-          check_nan(ff_output, "MLP ff_output")
 
         with self.conditional_named_scope("mlp_residual"):
           hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
             hidden_states.dtype
          )
-          check_nan(hidden_states, "MLP hidden_states after residual (Block Output)")
       return hidden_states
```
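With the per-layer prints gone from the transformer block, JAX's built-in NaN checker is the usual lighter-weight substitute. This is standard JAX configuration, not something this commit touches:

```python
import jax
import jax.numpy as jnp

# When enabled, JAX checks outputs for NaNs; on a hit inside jit it re-runs
# the function un-jitted and raises FloatingPointError at the offending op.
jax.config.update("jax_debug_nans", True)

def f(x):
    return jnp.log(x)  # NaN for negative input

jax.jit(f)(jnp.array(-1.0))  # raises FloatingPointError
```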

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
Lines changed: 29 additions & 2 deletions

```diff
@@ -30,6 +30,16 @@
   add_noise_common,
 )
 
+def check_nan_jit(tensor: jax.Array, name: str, step: jax.Array):
+  if tensor is None:
+    return
+
+  has_nans = jnp.isnan(tensor).any()
+  has_infs = jnp.isinf(tensor).any()
+  jax.debug.print(f"[DEBUG SCHEDULER {jax.process_index()}] Step: {{step}} - {name}: "
+                  "Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
+                  step=step, shape=tensor.shape, has_nans_val=has_nans, has_infs_val=has_infs)
+
 
 @flax.struct.dataclass
 class UniPCMultistepSchedulerState:
```
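One subtlety in `check_nan_jit`: the format string is an f-string, so `{jax.process_index()}` and `{name}` are baked in at trace time, while the doubled braces in `{{step}}` survive as a literal `{step}` placeholder for `jax.debug.print` to fill at run time. A standalone illustration of the escaping:

```python
import jax
import jax.numpy as jnp

name = "x_t"                                # interpolated at trace time
fmt = f"[{name}] step {{step}}: max={{m}}"  # -> "[x_t] step {step}: max={m}"

@jax.jit
def report(step, x):
    jax.debug.print(fmt, step=step, m=x.max())  # placeholders filled at run time
    return x

report(jnp.int32(3), jnp.arange(4.0))  # prints "[x_t] step 3: max=3.0"
```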
```diff
@@ -285,14 +295,18 @@ def convert_model_output(
       state: UniPCMultistepSchedulerState,
       model_output: jnp.ndarray,
       sample: jnp.ndarray,
+      step: jax.Array,
   ) -> jnp.ndarray:
     """
     Converts the model output based on the prediction type and current state.
     """
     sigma = state.sigmas[state.step_index]  # Current sigma
+    check_nan_jit(sigma, "convert_model_output sigma", step)
 
     # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t
     alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+    check_nan_jit(alpha_t, "convert_model_output alpha_t", step)
+    check_nan_jit(sigma_t, "convert_model_output sigma_t", step)
 
     if self.config.predict_x0:
       if self.config.prediction_type == "epsilon":
@@ -310,6 +324,7 @@ def convert_model_output(
             f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
             "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
         )
+      check_nan_jit(x0_pred, "convert_model_output x0_pred", step)
 
       if self.config.thresholding:
         raise NotImplementedError("Dynamic thresholding isn't implemented.")
@@ -336,6 +351,7 @@ def multistep_uni_p_bh_update(
       model_output: jnp.ndarray,
       sample: jnp.ndarray,
       order: int,
+      step: jax.Array,
   ) -> jnp.ndarray:
     """
     One step for the UniP (B(h) version) - the Predictor.
@@ -358,6 +374,7 @@ def multistep_uni_p_bh_update(
     lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
 
     h = lambda_t - lambda_s0
+    check_nan_jit(h, "predictor h", step)
 
     def rk_d1_loop_body(i, carry):
       # Loop from i = 0 to order-2
@@ -371,6 +388,7 @@ def rk_d1_loop_body(i, carry):
 
       rk = (lambda_si - lambda_s0) / h
       Di = (mi - m0) / rk
+      check_nan_jit(Di, f"predictor Di[{i}]", step)
 
       rks = rks.at[i].set(rk)
       D1s = D1s.at[i].set(Di)
@@ -467,7 +485,7 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
     else:  # Predict epsilon
       x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
       x_t = x_t_ - sigma_t * B_h * pred_res
-
+    check_nan_jit(x_t, "predictor x_t", step)
     return x_t.astype(x.dtype)
 
   def multistep_uni_c_bh_update(
@@ -477,6 +495,7 @@ def multistep_uni_c_bh_update(
       last_sample: jnp.ndarray,  # Sample after predictor `x_{t-1}`
       this_sample: jnp.ndarray,  # Sample before corrector `x_t` (after predictor step)
       order: int,
+      step: jax.Array,
   ) -> jnp.ndarray:
     """
     One step for the UniC (B(h) version) - the Corrector.
@@ -620,7 +639,8 @@ def solve_for_rhos(R_mat, b_vec, current_order):
     else:
       x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
       x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t)
-
+
+    check_nan_jit(x_t, "corrector x_t", step)
     return x_t.astype(x.dtype)
 
   def index_for_timestep(
@@ -674,6 +694,10 @@ def step(
     Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
     the multistep UniPC.
     """
+    step_val = state.step_index  # For debug, might be None initially
+
+    check_nan_jit(model_output, "step input model_output", step_val)
+    check_nan_jit(sample, "step input sample", step_val)
 
     sample = sample.astype(jnp.float32)
 
@@ -685,6 +709,7 @@ def step(
     # Initialize step_index if it's the first step
     if state.step_index is None:
       state = self._init_step_index(state, timestep_scalar)
+      step_val = state.step_index
 
     # Determine if corrector should be used
     use_corrector = (
@@ -695,6 +720,7 @@ def step(
 
     # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type
     model_output_for_history = self.convert_model_output(state, model_output, sample)
+    check_nan_jit(model_output_for_history, "model_output_for_history", step_val)
 
     # Apply corrector if applicable
     sample = jax.lax.cond(
@@ -708,6 +734,7 @@ def step(
       ),
       lambda: sample,
     )
+    check_nan_jit(sample, "sample_corrected", step_val)
 
     # Update history buffers (model_outputs and timestep_list)
     # Shift existing elements to the left and add new one at the end.
```
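Threading a `step` array through the predictor and corrector is what makes each print attributable inside traced control flow, where no Python-side counter exists. A self-contained sketch of the same idea (the loop body here is illustrative, not the scheduler's):

```python
import jax
import jax.numpy as jnp

def check(tensor, name, step):
    jax.debug.print("step {s} - " + name + ": NaNs={n}, Infs={i}",
                    s=step, n=jnp.isnan(tensor).any(), i=jnp.isinf(tensor).any())

@jax.jit
def sample_loop(x):
    def body(i, x):
        x = x / (x.sum() - x.sum())  # division by zero: produces Infs, then NaNs
        check(x, "x", i)             # i is the traced fori_loop index
        return x
    return jax.lax.fori_loop(0, 3, body, x)

sample_loop(jnp.ones(4))  # each iteration prints under its step index
```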
