Skip to content

Commit 685173f

Browse files
committed
NaN check in WanTransformerBlock
1 parent 91aa404 commit 685173f

1 file changed

Lines changed: 30 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040

4141
BlockSizes = common_types.BlockSizes
4242

43+
def check_nan(tensor: jax.Array, name: str) -> None:
  """Print a debug message when *tensor* contains any NaN values.

  Safe both eagerly and under ``jax.jit``: the NaN reduction is handed to the
  host through ``jax.debug.callback`` instead of being branched on in Python,
  which would raise a tracer-concretization error while the surrounding
  transformer block is being traced.

  Args:
    tensor: array to inspect for NaNs.
    name: human-readable label identifying the tensor in the log line.
  """

  def _report(has_nan) -> None:
    # Runs on the host with a concrete boolean, so a Python `if` is safe here.
    if has_nan:
      print(f"[DEBUG NaN Check] NaNs detected in {name} on process {jax.process_index()}")

  jax.debug.callback(_report, jnp.isnan(tensor).any())
4346

4447
def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, use_real: bool):
4548
h_dim = w_dim = 2 * (attention_head_dim // 6)
@@ -373,8 +376,15 @@ def __call__(
373376
deterministic: bool = True,
374377
rngs: nnx.Rngs = None,
375378
):
379+
check_nan(hidden_states, "TransformerBlock Input hidden_states")
380+
check_nan(encoder_hidden_states, "TransformerBlock Input encoder_hidden_states")
381+
check_nan(temb, "TransformerBlock Input temb")
382+
if rotary_emb is not None:
383+
check_nan(rotary_emb, "TransformerBlock Input rotary_emb")
376384
with self.conditional_named_scope("transformer_block"):
377385
with self.conditional_named_scope("adaln"):
386+
scale_shift_all = (self.adaln_scale_shift_table.value + temb.astype(jnp.float32))
387+
check_nan(scale_shift_all, "AdaLN scale_shift_all")
378388
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
379389
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
380390
)
@@ -385,9 +395,12 @@ def __call__(
385395
# 1. Self-attention
386396
with self.conditional_named_scope("self_attn"):
387397
with self.conditional_named_scope("self_attn_norm"):
388-
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
398+
norm_hidden_states = self.norm1(hidden_states.astype(jnp.float32))
399+
check_nan(norm_hidden_states, "Self-Attn norm1 output")
400+
norm_hidden_states = (norm_hidden_states * (1 + scale_msa) + shift_msa).astype(
389401
hidden_states.dtype
390402
)
403+
check_nan(norm_hidden_states, "Self-Attn norm_hidden_states after AdaLN")
391404
with self.conditional_named_scope("self_attn_attn"):
392405
attn_output = self.attn1(
393406
hidden_states=norm_hidden_states,
@@ -396,28 +409,42 @@ def __call__(
396409
deterministic=deterministic,
397410
rngs=rngs,
398411
)
412+
check_nan(attn_output, "Self-Attn attn_output (attn1)")
399413
with self.conditional_named_scope("self_attn_residual"):
400414
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
415+
check_nan(hidden_states, "Self-Attn hidden_states after residual")
401416

402417
# 2. Cross-attention
418+
residual = hidden_states
403419
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
420+
check_nan(norm_hidden_states, "Cross-Attn norm_hidden_states (norm2)")
404421
attn_output = self.attn2(
405422
hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
406423
)
407-
hidden_states = hidden_states + attn_output
424+
check_nan(attn_output, "Cross-Attn attn_output (attn2)")
425+
hidden_states = residual + attn_output
426+
check_nan(hidden_states, "Cross-Attn hidden_states after residual")
408427

409428
# 3. Feed-forward
429+
residual = hidden_states
410430
with self.conditional_named_scope("mlp"):
411431
with self.conditional_named_scope("mlp_norm"):
412-
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
432+
norm_hidden_states = self.norm3(hidden_states.astype(jnp.float32))
433+
check_nan(norm_hidden_states, "MLP norm3 output")
434+
norm_hidden_states = (norm_hidden_states * (1 + c_scale_msa) + c_shift_msa).astype(
413435
hidden_states.dtype
414436
)
437+
check_nan(norm_hidden_states, "MLP norm_hidden_states after AdaLN")
438+
415439
with self.conditional_named_scope("mlp_ffn"):
416440
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
441+
check_nan(ff_output, "MLP ff_output")
442+
417443
with self.conditional_named_scope("mlp_residual"):
418444
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
419445
hidden_states.dtype
420446
)
447+
check_nan(hidden_states, "MLP hidden_states after residual (Block Output)")
421448
return hidden_states
422449

423450

0 commit comments

Comments
 (0)