@@ -345,12 +345,14 @@ def __init__(
         dtype: DType = jnp.float32,
         attention_kernel: str = "flash",
         rope_type: str = "interleaved",
+        enable_jax_named_scopes: bool = False,
     ):
         self.heads = heads
         self.rope_type = rope_type
         self.dim_head = dim_head
         self.inner_dim = dim_head * heads
         self.dropout_rate = dropout
+        self.enable_jax_named_scopes = enable_jax_named_scopes

         # 1. Define Partitioned Initializers (Logical Axes)
         # Q, K, V kernels: [in_features (embed), out_features (heads)]
@@ -433,6 +435,11 @@ def __init__(
             axis_names_kv=("batch", "heads", "length", "kv"),
         )

+    def conditional_named_scope(self, name: str):
+        import jax
+        import contextlib
+        return jax.named_scope(name) if getattr(self, "enable_jax_named_scopes", False) else contextlib.nullcontext()
+
     def __call__(
         self,
         hidden_states: Array,
@@ -445,13 +452,15 @@ def __call__(
         context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

         # 1. Project
-        query = self.to_q(hidden_states)
-        key = self.to_k(context)
-        value = self.to_v(context)
+        with self.conditional_named_scope("proj_in"):
+            query = self.to_q(hidden_states)
+            key = self.to_k(context)
+            value = self.to_v(context)

         # 2. Norm (Full Inner Dimension)
-        query = self.norm_q(query)
-        key = self.norm_k(key)
+        with self.conditional_named_scope("norm"):
+            query = self.norm_q(query)
+            key = self.norm_k(key)

         # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
         # Frequencies are shape [B, S, InnerDim]
@@ -478,12 +487,14 @@ def __call__(

         # 4. Attention
         # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
-        attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
+        with self.conditional_named_scope("attention_op"):
+            attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

         # 7. Output Projection
-        hidden_states = self.to_out(attn_output)
-
-        if self.dropout_layer is not None:
-            hidden_states = self.dropout_layer(hidden_states)
+        with self.conditional_named_scope("proj_out"):
+            hidden_states = self.to_out(attn_output)
+
+            if self.dropout_layer is not None:
+                hidden_states = self.dropout_layer(hidden_states)

         return hidden_states
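
The change follows a single pattern: each stage of __call__ (projection, norm, attention, output projection) is wrapped in jax.named_scope when enable_jax_named_scopes=True, and in contextlib.nullcontext() otherwise, so the wrapper costs nothing when disabled and the profiler trace gains readable per-stage labels when enabled. Below is a minimal, self-contained sketch of the same toggle outside this class; the function and variable names here are illustrative, not part of the commit.

import contextlib

import jax
import jax.numpy as jnp

def conditional_named_scope(name: str, enabled: bool):
    # jax.named_scope tags the enclosed ops with a readable name in
    # profiler traces; nullcontext() is the no-op stand-in when disabled.
    return jax.named_scope(name) if enabled else contextlib.nullcontext()

@jax.jit
def attention_like(q, k):
    # "qk_matmul" appears as a labeled span in the trace viewer when enabled.
    with conditional_named_scope("qk_matmul", enabled=True):
        return q @ k.T

q = jnp.ones((4, 8))
k = jnp.ones((4, 8))
print(attention_like(q, k).shape)  # (4, 4)

Run under jax.profiler.trace(log_dir) and viewed in TensorBoard or Perfetto, the named spans group each stage's ops. Since named_scope only attaches metadata, toggling the flag does not change numerics.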