
Commit c1041dd

more debug added in attention_flax.py

1 parent cc73d06

2 files changed
Lines changed: 47 additions & 3 deletions

File tree

src/maxdiffusion/models/attention_flax.py
src/maxdiffusion/models/wan/transformers/transformer_wan.py

src/maxdiffusion/models/attention_flax.py
Lines changed: 45 additions & 2 deletions
@@ -51,6 +51,17 @@
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
 
+def check_nan_attn(tensor: Optional[jax.Array], name: str, tag: str):
+  if tensor is None: return
+  has_nans = jnp.isnan(tensor).any()
+  has_infs = jnp.isinf(tensor).any()
+  jax.debug.print(f"[DEBUG ATTN {tag}] {name}: "
+                  "Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
+                  has_nans_val=has_nans, has_infs_val=has_infs)
+  jax.lax.cond(has_nans | has_infs,  # a Python `if` on these traced bools would fail under jit
+               lambda: jax.debug.print(f"  {name} shape: {tensor.shape}, dtype: {tensor.dtype}"),
+               lambda: None)
+
 
 def _check_attention_inputs(query: Array, key: Array, value: Array) -> None:
   """Check attention inputs."""
@@ -945,7 +956,14 @@ def __call__(
       rotary_emb: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
+      tag: str = "attn",
   ) -> jax.Array:
+    check_nan_attn(hidden_states, "Input hidden_states", tag)
+    if encoder_hidden_states is not None:
+      check_nan_attn(encoder_hidden_states, "Input encoder_hidden_states", tag)
+    if rotary_emb is not None:
+      check_nan_attn(rotary_emb, "Input rotary_emb", tag)
+
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
     dtype = hidden_states.dtype
@@ -959,60 +977,79 @@ def __call__(
     with self.conditional_named_scope("attn_qkv_proj"):
       with self.conditional_named_scope("proj_query"):
         query_proj = self.query(hidden_states)
+        check_nan_attn(query_proj, "query_proj", tag)
       with self.conditional_named_scope("proj_key"):
         key_proj = self.key(encoder_hidden_states)
+        check_nan_attn(key_proj, "key_proj", tag)
       with self.conditional_named_scope("proj_value"):
         value_proj = self.value(encoder_hidden_states)
+        check_nan_attn(value_proj, "value_proj", tag)
 
       if self.qk_norm:
         with self.conditional_named_scope("attn_q_norm"):
           query_proj = self.norm_q(query_proj)
+          check_nan_attn(query_proj, "query_proj normed", tag)
         with self.conditional_named_scope("attn_k_norm"):
           key_proj = self.norm_k(key_proj)
+          check_nan_attn(key_proj, "key_proj normed", tag)
 
       if rotary_emb is not None:  # Only for SELF-ATTENTION
         with self.conditional_named_scope("attn_rope"):
           # Unflatten is done HERE for RoPE
           query_proj = _unflatten_heads(query_proj, self.heads)
+          check_nan_attn(query_proj, "query_proj unflattened", tag)
           key_proj = _unflatten_heads(key_proj, self.heads)
+          check_nan_attn(key_proj, "key_proj unflattened", tag)
           value_proj = _unflatten_heads(value_proj, self.heads)
+          check_nan_attn(value_proj, "value_proj unflattened", tag)
           # output of _unflatten_heads: (batch, heads, seq_len, head_dim)
           query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-
-
+          check_nan_attn(query_proj, "query_proj after RoPE", tag)
+          check_nan_attn(key_proj, "key_proj after RoPE", tag)
         query_proj = checkpoint_name(query_proj, "query_proj")
         key_proj = checkpoint_name(key_proj, "key_proj")
         value_proj = checkpoint_name(value_proj, "value_proj")
         with self.conditional_named_scope("attn_compute"):
           attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+          check_nan_attn(attn_output, "attn_output from attention_op", tag)
 
       else:
         # NEW PATH for I2V CROSS-ATTENTION
         with self.conditional_named_scope("proj_query"):
           query_proj = self.query(hidden_states)
+          check_nan_attn(query_proj, "query_proj I2V", tag)
         if self.qk_norm:
           with self.conditional_named_scope("attn_q_norm"):
             query_proj = self.norm_q(query_proj)
+            check_nan_attn(query_proj, "query_proj normed I2V", tag)
 
         encoder_hidden_states_img = encoder_hidden_states[:, :self.image_seq_len, :]
         encoder_hidden_states_text = encoder_hidden_states[:, self.image_seq_len:, :]
+        check_nan_attn(encoder_hidden_states_img, "EHS_img", tag)
+        check_nan_attn(encoder_hidden_states_text, "EHS_text", tag)
 
         # Text K/V
         with self.conditional_named_scope("proj_key"):
           key_proj_text = self.key(encoder_hidden_states_text)
+          check_nan_attn(key_proj_text, "key_proj_text", tag)
         if self.qk_norm:
           with self.conditional_named_scope("attn_k_norm"):
             key_proj_text = self.norm_k(key_proj_text)
+            check_nan_attn(key_proj_text, "key_proj_text normed", tag)
         with self.conditional_named_scope("proj_value"):
           value_proj_text = self.value(encoder_hidden_states_text)
+          check_nan_attn(value_proj_text, "value_proj_text", tag)
 
         # Image K/V
         with self.conditional_named_scope("add_proj_k"):
           key_proj_img = self.add_k_proj(encoder_hidden_states_img)
+          check_nan_attn(key_proj_img, "key_proj_img", tag)
         with self.conditional_named_scope("norm_add_k"):
           key_proj_img = self.norm_added_k(key_proj_img)
+          check_nan_attn(key_proj_img, "key_proj_img normed", tag)
         with self.conditional_named_scope("add_proj_v"):
           value_proj_img = self.add_v_proj(encoder_hidden_states_img)
+          check_nan_attn(value_proj_img, "value_proj_img", tag)
 
         # Checkpointing
         query_proj = checkpoint_name(query_proj, "query_proj")
@@ -1024,19 +1061,25 @@ def __call__(
         # Attention - tensors are (B, S, D)
         with self.conditional_named_scope("cross_attn_text_apply"):
           attn_output_text = self.attention_op.apply_attention(query_proj, key_proj_text, value_proj_text)
+          check_nan_attn(attn_output_text, "attn_output_text_h", tag)
         with self.conditional_named_scope("norm_added_q"):
           query_proj_img = self.norm_added_q(query_proj)
+          check_nan_attn(query_proj_img, "query_proj_img normed", tag)
         with self.conditional_named_scope("cross_attn_img_apply"):
           attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img)
+          check_nan_attn(attn_output_img, "attn_output_img", tag)
 
         attn_output = attn_output_text + attn_output_img
+        check_nan_attn(attn_output, "attn_output final I2V", tag)
 
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
 
     with self.conditional_named_scope("attn_out_proj"):
       hidden_states = self.proj_attn(attn_output)
+      check_nan_attn(hidden_states, "hidden_states after proj_attn", tag)
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    check_nan_attn(hidden_states, "hidden_states after dropout", tag)
     return hidden_states
 
 
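Each check above emits one status line via jax.debug.print, so a failing run can be bisected by scanning the log for the first True. Illustrative output only (the tensor states are made up; the format comes straight from check_nan_attn, with a shape/dtype line following whenever non-finite values are detected):

[DEBUG ATTN SELF] Input hidden_states: Has NaNs: False, Has Infs: False
[DEBUG ATTN SELF] query_proj: Has NaNs: False, Has Infs: False
[DEBUG ATTN CROSS] key_proj_text normed: Has NaNs: True, Has Infs: False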
src/maxdiffusion/models/wan/transformers/transformer_wan.py
Lines changed: 2 additions & 1 deletion
@@ -420,6 +420,7 @@ def __call__(
           rotary_emb=rotary_emb,
           deterministic=deterministic,
           rngs=rngs,
+          tag="SELF",
       )
       check_nan(attn_output, "Self-Attn attn_output (attn1)")
       with self.conditional_named_scope("self_attn_residual"):
@@ -431,7 +432,7 @@ def __call__(
       norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
       check_nan(norm_hidden_states, "Cross-Attn norm_hidden_states (norm2)")
       attn_output = self.attn2(
-          hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+          hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs, tag="CROSS"
       )
       check_nan(attn_output, "Cross-Attn attn_output (attn2)")
       hidden_states = residual + attn_output
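Aside (a workflow suggestion, not part of this commit): for this kind of NaN hunt, JAX's built-in NaN checker can complement the manual SELF/CROSS tagging; when a jitted computation produces a NaN, it re-runs the computation op by op and raises FloatingPointError at the first offending op.

import jax
# Global debug flag; slows execution noticeably, so enable only while debugging.
jax.config.update("jax_debug_nans", True)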
