@@ -473,7 +473,15 @@ def test_tpu_kernel_attention_gqa(self):
   def test_tpu_kernel_attention_mqa(self):
     self.tpu_kernel_attention_helper(1)

-  def tpu_kernel_attention_helper(self, num_kv_heads):
+  @pytest.mark.tpu_only
+  def test_tpu_kernel_attention_mha_share_kv(self):
+    self.tpu_kernel_attention_helper(self.num_kv_heads, share_kv_projections=True)
+
+  @pytest.mark.tpu_only
+  def test_tpu_kernel_attention_gqa_share_kv(self):
+    self.tpu_kernel_attention_helper(self.num_kv_heads // 2, share_kv_projections=True)
+
+  def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False):
     """Test equivalence between dot_product and TPU accelerated"""

     lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype)
@@ -493,6 +501,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         attention_kernel="dot_product",
         dtype=self.dtype,
         dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=share_kv_projections,
         rngs=self.nnx_rng,
     )

@@ -522,6 +531,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         attention_kernel="flash",
         dtype=self.dtype,
         dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=share_kv_projections,
         rngs=self.nnx_rng,
     )
     nnx.update(attention_as_mha_flash, generic_state)
@@ -539,6 +549,84 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False)
     )

+  def test_share_kv_projections(self):
+    """Test that kv projections are shared."""
+    dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim))
+    dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim))
+    attention_share_kv = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.cfg.max_prefill_predict_length,
+        inputs_q_shape=dummy_inputs_q.shape,
+        inputs_kv_shape=dummy_inputs_kv.shape,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=True,
+        rngs=self.nnx_rng,
+    )
+
+    self.assertFalse(hasattr(attention_share_kv, "value"))
+    self.assertTrue(hasattr(attention_share_kv, "key"))
+
+    # 1. Check NNX state
+    state_shared = nnx.state(attention_share_kv)
+    self.assertNotIn("value", state_shared)
+    self.assertIn("key", state_shared)
+
+    # 2. Forward pass verification
+    lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype)
+
+    output_shared, _ = attention_share_kv(
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_TRAIN,
+    )
+
+    self.assertEqual(output_shared.shape, (self.global_batch_size, self.max_target_length, self.embed_dim))
+
+    # 3. Equivalence check against a standard unshared Attention
+    attention_no_share = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.cfg.max_prefill_predict_length,
+        inputs_q_shape=dummy_inputs_q.shape,
+        inputs_kv_shape=dummy_inputs_kv.shape,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=False,
+        rngs=self.nnx_rng,
+    )
+
+    # Force the unshared layer to copy weights from the shared layer, mapping 'key' to 'value'
+    attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value
+    attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value
+    attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value
+    attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value
+
+    output_no_share, _ = attention_no_share(
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_TRAIN,
+    )
+
+    self.assertTrue(jax.numpy.allclose(output_shared, output_no_share, rtol=1e-04, atol=1e-04, equal_nan=False))
+
   @parameterized.named_parameters(
       {
           "testcase_name": "cp_no_load_balance",