def _maybe_aqt_einsum(quant: Quant):
    """Return the einsum implementation to use.

    Falls back to the plain ``jnp.einsum`` when no quantization config is
    supplied; otherwise defers to the AQT-provided einsum from ``quant``.
    """
    if quant is None:
        return jnp.einsum
    return quant.einsum()
5353
def check_nan_attn(tensor: jax.Array, name: str, tag: "str | int") -> None:
    """Debug probe: report whether `tensor` contains any NaNs or Infs.

    Args:
        tensor: Array to inspect; ``None`` is tolerated and skipped so call
            sites don't need their own guard.
        name: Human-readable label for the tensor in the log line.
        tag: Identifier of the call site (call sites pass a string tag such
            as "attn"; an int device id also works since it is only printed).

    Returns:
        None. Output is emitted via ``jax.debug.print`` so it also fires
        inside jit-compiled code.
    """
    if tensor is None:
        return
    has_nans = jnp.isnan(tensor).any()
    has_infs = jnp.isinf(tensor).any()
    # shape/dtype are static, so they may be baked into the f-string; the
    # NaN/Inf flags are traced values and must go through debug.print
    # placeholders. Note: branching on the traced flags (the previous
    # `if has_nans or has_infs:`) raises TracerBoolConversionError under
    # jax.jit, so everything is emitted in one unconditional print instead.
    jax.debug.print(
        f"[DEBUG ATTN {tag}] {name} "
        f"(shape: {tensor.shape}, dtype: {tensor.dtype}): "
        "Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
        has_nans_val=has_nans,
        has_infs_val=has_infs,
    )
64+
5465
5566def _check_attention_inputs (query : Array , key : Array , value : Array ) -> None :
5667 """Check attention inputs."""
@@ -945,7 +956,14 @@ def __call__(
945956 rotary_emb : Optional [jax .Array ] = None ,
946957 deterministic : bool = True ,
947958 rngs : nnx .Rngs = None ,
959+ tag : str = "attn"
948960 ) -> jax .Array :
961+ check_nan_attn (hidden_states , "Input hidden_states" , tag )
962+ if encoder_hidden_states is not None :
963+ check_nan_attn (encoder_hidden_states , "Input encoder_hidden_states" , tag )
964+ if rotary_emb is not None :
965+ check_nan_attn (rotary_emb , "Input rotary_emb" , tag )
966+
949967 hidden_states = jax .lax .with_sharding_constraint (hidden_states , PartitionSpec ("data" , "fsdp" , "tensor" ))
950968 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , PartitionSpec ("data" , "fsdp" , "tensor" ))
951969 dtype = hidden_states .dtype
@@ -959,60 +977,79 @@ def __call__(
959977 with self .conditional_named_scope ("attn_qkv_proj" ):
960978 with self .conditional_named_scope ("proj_query" ):
961979 query_proj = self .query (hidden_states )
980+ check_nan_attn (query_proj , "query_proj" , tag )
962981 with self .conditional_named_scope ("proj_key" ):
963982 key_proj = self .key (encoder_hidden_states )
983+ check_nan_attn (key_proj , "key_proj" , tag )
964984 with self .conditional_named_scope ("proj_value" ):
965985 value_proj = self .value (encoder_hidden_states )
986+ check_nan_attn (value_proj , "value_proj" , tag )
966987
967988 if self .qk_norm :
968989 with self .conditional_named_scope ("attn_q_norm" ):
969990 query_proj = self .norm_q (query_proj )
991+ check_nan_attn (query_proj , "query_proj normed" , tag )
970992 with self .conditional_named_scope ("attn_k_norm" ):
971993 key_proj = self .norm_k (key_proj )
994+ check_nan_attn (key_proj , "key_proj normed" , tag )
972995
973996 if rotary_emb is not None : # Only for SELF-ATTENTION
974997 with self .conditional_named_scope ("attn_rope" ):
975998 # Unflatten is done HERE for RoPE
976999 query_proj = _unflatten_heads (query_proj , self .heads )
1000+ check_nan_attn (query_proj , "query_proj unflattened" , tag )
9771001 key_proj = _unflatten_heads (key_proj , self .heads )
1002+ check_nan_attn (key_proj , "key_proj unflattened" , tag )
9781003 value_proj = _unflatten_heads (value_proj , self .heads )
1004+ check_nan_attn (value_proj , "value_proj unflattened" , tag )
9791005 # output of _unflatten_heads Batch, heads, seq_len, head_dim
9801006 query_proj , key_proj = self ._apply_rope (query_proj , key_proj , rotary_emb )
981-
982-
1007+ check_nan_attn ( query_proj , "query_proj after RoPE" , tag )
1008+ check_nan_attn ( key_proj , "key_proj after RoPE" , tag )
9831009 query_proj = checkpoint_name (query_proj , "query_proj" )
9841010 key_proj = checkpoint_name (key_proj , "key_proj" )
9851011 value_proj = checkpoint_name (value_proj , "value_proj" )
9861012 with self .conditional_named_scope ("attn_compute" ):
9871013 attn_output = self .attention_op .apply_attention (query_proj , key_proj , value_proj )
1014+ check_nan_attn (attn_output , "attn_output from attention_op" , tag )
9881015
9891016 else :
9901017 # NEW PATH for I2V CROSS-ATTENTION
9911018 with self .conditional_named_scope ("proj_query" ):
9921019 query_proj = self .query (hidden_states )
1020+ check_nan_attn (query_proj , "query_proj I2V" , tag )
9931021 if self .qk_norm :
9941022 with self .conditional_named_scope ("attn_q_norm" ):
9951023 query_proj = self .norm_q (query_proj )
1024+ check_nan_attn (query_proj , "query_proj normed I2V" , tag )
9961025
9971026 encoder_hidden_states_img = encoder_hidden_states [:, :self .image_seq_len , :]
9981027 encoder_hidden_states_text = encoder_hidden_states [:, self .image_seq_len :, :]
1028+ check_nan_attn (encoder_hidden_states_img , "EHS_img" , tag )
1029+ check_nan_attn (encoder_hidden_states_text , "EHS_text" , tag )
9991030
10001031 # Text K/V
10011032 with self .conditional_named_scope ("proj_key" ):
10021033 key_proj_text = self .key (encoder_hidden_states_text )
1034+ check_nan_attn (key_proj_text , "key_proj_text" , tag )
10031035 if self .qk_norm :
10041036 with self .conditional_named_scope ("attn_k_norm" ):
10051037 key_proj_text = self .norm_k (key_proj_text )
1038+ check_nan_attn (key_proj_text , "key_proj_text normed" , tag )
10061039 with self .conditional_named_scope ("proj_value" ):
10071040 value_proj_text = self .value (encoder_hidden_states_text )
1041+ check_nan_attn (value_proj_text , "value_proj_text" , tag )
10081042
10091043 # Image K/V
10101044 with self .conditional_named_scope ("add_proj_k" ):
10111045 key_proj_img = self .add_k_proj (encoder_hidden_states_img )
1046+ check_nan_attn (key_proj_img , "key_proj_img" , tag )
10121047 with self .conditional_named_scope ("norm_add_k" ):
10131048 key_proj_img = self .norm_added_k (key_proj_img )
1049+ check_nan_attn (key_proj_img , "key_proj_img normed" , tag )
10141050 with self .conditional_named_scope ("add_proj_v" ):
10151051 value_proj_img = self .add_v_proj (encoder_hidden_states_img )
1052+ check_nan_attn (value_proj_img , "value_proj_img" , tag )
10161053
10171054 # Checkpointing
10181055 query_proj = checkpoint_name (query_proj , "query_proj" )
@@ -1024,19 +1061,25 @@ def __call__(
10241061 # Attention - tensors are (B, S, D)
10251062 with self .conditional_named_scope ("cross_attn_text_apply" ):
10261063 attn_output_text = self .attention_op .apply_attention (query_proj , key_proj_text , value_proj_text )
1064+ check_nan_attn (attn_output_text , "attn_output_text_h" , tag )
10271065 with self .conditional_named_scope ("norm_added_q" ):
10281066 query_proj_img = self .norm_added_q (query_proj )
1067+ check_nan_attn (query_proj_img , "query_proj_img normed" , tag )
10291068 with self .conditional_named_scope ("cross_attn_img_apply" ):
10301069 attn_output_img = self .attention_op .apply_attention (query_proj_img , key_proj_img , value_proj_img )
1070+ check_nan_attn (attn_output_img , "attn_output_img" , tag )
10311071
10321072 attn_output = attn_output_text + attn_output_img
1073+ check_nan_attn (attn_output , "attn_output final I2V" , tag )
10331074
10341075 attn_output = attn_output .astype (dtype = dtype )
10351076 attn_output = checkpoint_name (attn_output , "attn_output" )
10361077
10371078 with self .conditional_named_scope ("attn_out_proj" ):
10381079 hidden_states = self .proj_attn (attn_output )
1080+ check_nan_attn (hidden_states , "hidden_states after proj_attn" , tag )
10391081 hidden_states = self .drop_out (hidden_states , deterministic = deterministic , rngs = rngs )
1082+ check_nan_attn (hidden_states , "hidden_states after dropout" , tag )
10401083 return hidden_states
10411084
10421085
0 commit comments