5151def _maybe_aqt_einsum (quant : Quant ):
5252 return jnp .einsum if quant is None else quant .einsum ()
5353
def check_nan_attn(tensor: jax.Array, name: str, tag: str = ""):
    """Debug helper: report whether *tensor* holds any NaN or Inf values.

    Safe to call from inside a jit-traced function: the NaN/Inf flags are
    traced arrays and are emitted through ``jax.debug.print``, which defers
    formatting to runtime.

    Args:
      tensor: Array to inspect, or ``None`` (reported host-side and skipped).
      name: Human-readable label for the tensor in the log line.
      tag: Optional extra label (e.g. which attention layer).
    """
    if tensor is None:
        # Host-side print is fine here; a None argument never becomes a tracer.
        print(f"[DEBUG ATTN PY {jax.process_index()}] {tag} {name}: Tensor is None")
        return

    # Traced boolean scalars (abstract tracers when called under jit).
    nan_flag = jnp.isnan(tensor).any()
    inf_flag = jnp.isinf(tensor).any()

    # The process index, tag and name are static, so they are baked into the
    # format string via the f-string; the traced flags are passed as keyword
    # arguments and formatted by jax.debug.print at runtime.
    jax.debug.print(
        f"[DEBUG ATTN JIT {jax.process_index()}] {tag} {name}: "
        "Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
        shape=tensor.shape,
        has_nans_val=nan_flag,
        has_infs_val=inf_flag,
    )
6954
7055
7156def _check_attention_inputs (query : Array , key : Array , value : Array ) -> None :
@@ -961,13 +946,7 @@ def __call__(
961946 rotary_emb : Optional [jax .Array ] = None ,
962947 deterministic : bool = True ,
963948 rngs : nnx .Rngs = None ,
964- tag : str = "attn"
965949 ) -> jax .Array :
966- check_nan_attn (hidden_states , "Input hidden_states" , tag )
967- if encoder_hidden_states is not None :
968- check_nan_attn (encoder_hidden_states , "Input encoder_hidden_states" , tag )
969- if rotary_emb is not None :
970- check_nan_attn (rotary_emb , "Input rotary_emb" , tag )
971950
972951 hidden_states = jax .lax .with_sharding_constraint (hidden_states , PartitionSpec ("data" , "fsdp" , "tensor" ))
973952 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , PartitionSpec ("data" , "fsdp" , "tensor" ))
@@ -982,79 +961,58 @@ def __call__(
982961 with self .conditional_named_scope ("attn_qkv_proj" ):
983962 with self .conditional_named_scope ("proj_query" ):
984963 query_proj = self .query (hidden_states )
985- check_nan_attn (query_proj , "query_proj" , tag )
986964 with self .conditional_named_scope ("proj_key" ):
987965 key_proj = self .key (encoder_hidden_states )
988- check_nan_attn (key_proj , "key_proj" , tag )
989966 with self .conditional_named_scope ("proj_value" ):
990967 value_proj = self .value (encoder_hidden_states )
991- check_nan_attn (value_proj , "value_proj" , tag )
992968
993969 if self .qk_norm :
994970 with self .conditional_named_scope ("attn_q_norm" ):
995971 query_proj = self .norm_q (query_proj )
996- check_nan_attn (query_proj , "query_proj normed" , tag )
997972 with self .conditional_named_scope ("attn_k_norm" ):
998973 key_proj = self .norm_k (key_proj )
999- check_nan_attn (key_proj , "key_proj normed" , tag )
1000974
1001975 if rotary_emb is not None : # Only for SELF-ATTENTION
1002976 with self .conditional_named_scope ("attn_rope" ):
1003977 # Unflatten is done HERE for RoPE
1004978 query_proj = _unflatten_heads (query_proj , self .heads )
1005- check_nan_attn (query_proj , "query_proj unflattened" , tag )
1006979 key_proj = _unflatten_heads (key_proj , self .heads )
1007- check_nan_attn (key_proj , "key_proj unflattened" , tag )
1008980 value_proj = _unflatten_heads (value_proj , self .heads )
1009- check_nan_attn (value_proj , "value_proj unflattened" , tag )
1010981 # output of _unflatten_heads Batch, heads, seq_len, head_dim
1011982 query_proj , key_proj = self ._apply_rope (query_proj , key_proj , rotary_emb )
1012- check_nan_attn (query_proj , "query_proj after RoPE" , tag )
1013- check_nan_attn (key_proj , "key_proj after RoPE" , tag )
1014983 query_proj = checkpoint_name (query_proj , "query_proj" )
1015984 key_proj = checkpoint_name (key_proj , "key_proj" )
1016985 value_proj = checkpoint_name (value_proj , "value_proj" )
1017986 with self .conditional_named_scope ("attn_compute" ):
1018987 attn_output = self .attention_op .apply_attention (query_proj , key_proj , value_proj )
1019- check_nan_attn (attn_output , "attn_output from attention_op" , tag )
1020988
1021989 else :
1022990 # NEW PATH for I2V CROSS-ATTENTION
1023991 with self .conditional_named_scope ("proj_query" ):
1024992 query_proj = self .query (hidden_states )
1025- check_nan_attn (query_proj , "query_proj I2V" , tag )
1026993 if self .qk_norm :
1027994 with self .conditional_named_scope ("attn_q_norm" ):
1028995 query_proj = self .norm_q (query_proj )
1029- check_nan_attn (query_proj , "query_proj normed I2V" , tag )
1030996
1031997 encoder_hidden_states_img = encoder_hidden_states [:, :self .image_seq_len , :]
1032998 encoder_hidden_states_text = encoder_hidden_states [:, self .image_seq_len :, :]
1033- check_nan_attn (encoder_hidden_states_img , "EHS_img" , tag )
1034- check_nan_attn (encoder_hidden_states_text , "EHS_text" , tag )
1035999
10361000 # Text K/V
10371001 with self .conditional_named_scope ("proj_key" ):
10381002 key_proj_text = self .key (encoder_hidden_states_text )
1039- check_nan_attn (key_proj_text , "key_proj_text" , tag )
10401003 if self .qk_norm :
10411004 with self .conditional_named_scope ("attn_k_norm" ):
10421005 key_proj_text = self .norm_k (key_proj_text )
1043- check_nan_attn (key_proj_text , "key_proj_text normed" , tag )
10441006 with self .conditional_named_scope ("proj_value" ):
10451007 value_proj_text = self .value (encoder_hidden_states_text )
1046- check_nan_attn (value_proj_text , "value_proj_text" , tag )
10471008
10481009 # Image K/V
10491010 with self .conditional_named_scope ("add_proj_k" ):
10501011 key_proj_img = self .add_k_proj (encoder_hidden_states_img )
1051- check_nan_attn (key_proj_img , "key_proj_img" , tag )
10521012 with self .conditional_named_scope ("norm_add_k" ):
10531013 key_proj_img = self .norm_added_k (key_proj_img )
1054- check_nan_attn (key_proj_img , "key_proj_img normed" , tag )
10551014 with self .conditional_named_scope ("add_proj_v" ):
10561015 value_proj_img = self .add_v_proj (encoder_hidden_states_img )
1057- check_nan_attn (value_proj_img , "value_proj_img" , tag )
10581016
10591017 # Checkpointing
10601018 query_proj = checkpoint_name (query_proj , "query_proj" )
@@ -1066,25 +1024,19 @@ def __call__(
10661024 # Attention - tensors are (B, S, D)
10671025 with self .conditional_named_scope ("cross_attn_text_apply" ):
10681026 attn_output_text = self .attention_op .apply_attention (query_proj , key_proj_text , value_proj_text )
1069- check_nan_attn (attn_output_text , "attn_output_text_h" , tag )
10701027 with self .conditional_named_scope ("norm_added_q" ):
10711028 query_proj_img = self .norm_added_q (query_proj )
1072- check_nan_attn (query_proj_img , "query_proj_img normed" , tag )
10731029 with self .conditional_named_scope ("cross_attn_img_apply" ):
10741030 attn_output_img = self .attention_op .apply_attention (query_proj_img , key_proj_img , value_proj_img )
1075- check_nan_attn (attn_output_img , "attn_output_img" , tag )
10761031
10771032 attn_output = attn_output_text + attn_output_img
1078- check_nan_attn (attn_output , "attn_output final I2V" , tag )
10791033
10801034 attn_output = attn_output .astype (dtype = dtype )
10811035 attn_output = checkpoint_name (attn_output , "attn_output" )
10821036
10831037 with self .conditional_named_scope ("attn_out_proj" ):
10841038 hidden_states = self .proj_attn (attn_output )
1085- check_nan_attn (hidden_states , "hidden_states after proj_attn" , tag )
10861039 hidden_states = self .drop_out (hidden_states , deterministic = deterministic , rngs = rngs )
1087- check_nan_attn (hidden_states , "hidden_states after dropout" , tag )
10881040 return hidden_states
10891041
10901042
0 commit comments