4040
4141BlockSizes = common_types .BlockSizes
4242
def check_nan(tensor: jax.Array, name: str):
  """Emit a debug line on stdout when *tensor* contains at least one NaN.

  Args:
    tensor: array to scan for NaN entries.
    name: human-readable label identifying the tensor in the log line.
  """
  # Reduce the whole array to one scalar flag: True iff any element is NaN.
  nan_present = jnp.isnan(tensor).any()
  if not nan_present:
    return
  # Tag the message with the process index so multi-host logs stay attributable.
  print(f"[DEBUG NaN Check] NaNs detected in {name} on process {jax.process_index()}")
4346
4447def get_frequencies (max_seq_len : int , theta : int , attention_head_dim : int , use_real : bool ):
4548 h_dim = w_dim = 2 * (attention_head_dim // 6 )
@@ -373,8 +376,15 @@ def __call__(
373376 deterministic : bool = True ,
374377 rngs : nnx .Rngs = None ,
375378 ):
379+ check_nan (hidden_states , "TransformerBlock Input hidden_states" )
380+ check_nan (encoder_hidden_states , "TransformerBlock Input encoder_hidden_states" )
381+ check_nan (temb , "TransformerBlock Input temb" )
382+ if rotary_emb is not None :
383+ check_nan (rotary_emb , "TransformerBlock Input rotary_emb" )
376384 with self .conditional_named_scope ("transformer_block" ):
377385 with self .conditional_named_scope ("adaln" ):
386+ scale_shift_all = (self .adaln_scale_shift_table .value + temb .astype (jnp .float32 ))
387+ check_nan (scale_shift_all , "AdaLN scale_shift_all" )
378388 shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
379389 (self .adaln_scale_shift_table + temb .astype (jnp .float32 )), 6 , axis = 1
380390 )
@@ -385,9 +395,12 @@ def __call__(
385395 # 1. Self-attention
386396 with self .conditional_named_scope ("self_attn" ):
387397 with self .conditional_named_scope ("self_attn_norm" ):
388- norm_hidden_states = (self .norm1 (hidden_states .astype (jnp .float32 )) * (1 + scale_msa ) + shift_msa ).astype (
398+ norm_hidden_states = self .norm1 (hidden_states .astype (jnp .float32 ))
399+ check_nan (norm_hidden_states , "Self-Attn norm1 output" )
400+ norm_hidden_states = (norm_hidden_states * (1 + scale_msa ) + shift_msa ).astype (
389401 hidden_states .dtype
390402 )
403+ check_nan (norm_hidden_states , "Self-Attn norm_hidden_states after AdaLN" )
391404 with self .conditional_named_scope ("self_attn_attn" ):
392405 attn_output = self .attn1 (
393406 hidden_states = norm_hidden_states ,
@@ -396,28 +409,42 @@ def __call__(
396409 deterministic = deterministic ,
397410 rngs = rngs ,
398411 )
412+ check_nan (attn_output , "Self-Attn attn_output (attn1)" )
399413 with self .conditional_named_scope ("self_attn_residual" ):
400414 hidden_states = (hidden_states .astype (jnp .float32 ) + attn_output * gate_msa ).astype (hidden_states .dtype )
415+ check_nan (hidden_states , "Self-Attn hidden_states after residual" )
401416
402417 # 2. Cross-attention
418+ residual = hidden_states
403419 norm_hidden_states = self .norm2 (hidden_states .astype (jnp .float32 )).astype (hidden_states .dtype )
420+ check_nan (norm_hidden_states , "Cross-Attn norm_hidden_states (norm2)" )
404421 attn_output = self .attn2 (
405422 hidden_states = norm_hidden_states , encoder_hidden_states = encoder_hidden_states , deterministic = deterministic , rngs = rngs
406423 )
407- hidden_states = hidden_states + attn_output
424+ check_nan (attn_output , "Cross-Attn attn_output (attn2)" )
425+ hidden_states = residual + attn_output
426+ check_nan (hidden_states , "Cross-Attn hidden_states after residual" )
408427
409428 # 3. Feed-forward
429+ residual = hidden_states
410430 with self .conditional_named_scope ("mlp" ):
411431 with self .conditional_named_scope ("mlp_norm" ):
412- norm_hidden_states = (self .norm3 (hidden_states .astype (jnp .float32 )) * (1 + c_scale_msa ) + c_shift_msa ).astype (
432+ norm_hidden_states = self .norm3 (hidden_states .astype (jnp .float32 ))
433+ check_nan (norm_hidden_states , "MLP norm3 output" )
434+ norm_hidden_states = (norm_hidden_states * (1 + c_scale_msa ) + c_shift_msa ).astype (
413435 hidden_states .dtype
414436 )
437+ check_nan (norm_hidden_states , "MLP norm_hidden_states after AdaLN" )
438+
415439 with self .conditional_named_scope ("mlp_ffn" ):
416440 ff_output = self .ffn (norm_hidden_states , deterministic = deterministic , rngs = rngs )
441+ check_nan (ff_output , "MLP ff_output" )
442+
417443 with self .conditional_named_scope ("mlp_residual" ):
418444 hidden_states = (hidden_states .astype (jnp .float32 ) + ff_output .astype (jnp .float32 ) * c_gate_msa ).astype (
419445 hidden_states .dtype
420446 )
447+ check_nan (hidden_states , "MLP hidden_states after residual (Block Output)" )
421448 return hidden_states
422449
423450
0 commit comments