
Commit 96b72fb

Merge pull request #3283 from AI-Hypercomputer:agagik-shared-kv
PiperOrigin-RevId: 878275312
2 parents: 3d6a5de + fd5ae5f

5 files changed

Lines changed: 107 additions & 3 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -331,6 +331,7 @@ param_scan_axis: 1
 # The attention_type parameter determines the variants of attention, e.g. global or local_sliding
 attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
 attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla
+share_kv_projections: False # Note: Not compatible with attention_type='mla'
 attention_bias: False # If True, adds a learnable bias to the query, key, and value projections
 attention_sink: False
 sliding_window_size: 0
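
The new flag defaults to False, so existing configs keep separate key and value projections unless it is turned on explicitly. A minimal sketch of reading the default back out of the YAML file, assuming PyYAML is installed and the repository layout shown in this diff:

import yaml  # third-party PyYAML, assumed available

with open("src/maxtext/configs/base.yml") as f:
  cfg = yaml.safe_load(f)

# Off by default; per the comment above it must stay off when attention_type='mla'.
print(cfg["share_kv_projections"])  # expected: False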

src/maxtext/configs/types.py

Lines changed: 6 additions & 0 deletions
@@ -476,6 +476,7 @@ class Attention(BaseModel):
       "autoselected",
       description="The attention algorithm to use (dot_product, flash, etc).",
   )
+  share_kv_projections: bool = Field(False, description="If True, Key and Value use the same projection.")
   attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field(
       "global", description="The variant of attention to use."
   )
@@ -2505,6 +2506,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           "Please disable attn_logits_soft_cap when using use_qk_clip."
       )
 
+    if self.share_kv_projections and self.fused_qkv:
+      raise ValueError("`share_kv_projections` is not compatible with `fused_qkv`.")
+    if self.share_kv_projections and self.attention_type == "mla":
+      raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")
+
     # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
     if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
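
These checks reject key/value sharing in configurations where it cannot apply: with fused_qkv, Q, K, and V already come from one fused projection, and with MLA the key/value path goes through a latent projection rather than plain per-head KV weights. A small standalone sketch of the same mutual-exclusion rule (plain Python, not the actual pydantic Config model; field names mirror the diff):

from dataclasses import dataclass

@dataclass
class AttentionCfgSketch:
  # Hypothetical stand-in for the relevant Config fields.
  share_kv_projections: bool = False
  fused_qkv: bool = False
  attention_type: str = "global"

  def validate(self) -> None:
    if self.share_kv_projections and self.fused_qkv:
      raise ValueError("`share_kv_projections` is not compatible with `fused_qkv`.")
    if self.share_kv_projections and self.attention_type == "mla":
      raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")

AttentionCfgSketch(share_kv_projections=True).validate()  # accepted
try:
  AttentionCfgSketch(share_kv_projections=True, attention_type="mla").validate()
except ValueError as err:
  print(err)  # rejected at config-validation time, before any model is built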

src/maxtext/layers/attentions.py

Lines changed: 10 additions & 2 deletions
@@ -130,6 +130,7 @@ def attention_as_linen(
     use_qk_norm: bool = False,
     query_pre_attn_scalar: float | None = None,
     use_bias_in_projections: bool = False, # Set to True will enable bias in q, k, v, o projections
+    share_kv_projections: bool = False, # If true, Key and Value use the same projection
     # Temperature tuning parameters used for Llama4
     temperature_tuning: bool = False,
     temperature_tuning_scale: float = 0.1,
@@ -199,6 +200,7 @@ def attention_as_linen(
       use_qk_norm=use_qk_norm,
       query_pre_attn_scalar=query_pre_attn_scalar,
       use_bias_in_projections=use_bias_in_projections,
+      share_kv_projections=share_kv_projections,
       temperature_tuning=temperature_tuning,
       temperature_tuning_scale=temperature_tuning_scale,
       temperature_tuning_floor_scale=temperature_tuning_floor_scale,
@@ -295,6 +297,7 @@ def __init__(
       use_qk_norm: bool = False,
       query_pre_attn_scalar: float | None = None,
       use_bias_in_projections: bool = False, # Set to True will enable bias in q, k, v, o projections
+      share_kv_projections: bool = False, # If true, Key and Value use the same projection
       # Temperature tuning parameters used for Llama4
       temperature_tuning: bool = False,
       temperature_tuning_scale: float = 0.1,
@@ -399,6 +402,7 @@ def __init__(
     self.use_qk_norm = use_qk_norm
     self.query_pre_attn_scalar = query_pre_attn_scalar
     self.use_bias_in_projections = use_bias_in_projections
+    self.share_kv_projections = share_kv_projections
     self.temperature_tuning = temperature_tuning
     self.temperature_tuning_scale = temperature_tuning_scale
     self.temperature_tuning_floor_scale = temperature_tuning_floor_scale
@@ -559,7 +563,8 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
     else:
       self.query = self.init_query_w(inputs_q_shape=inputs_q_shape)
       self.key = self.init_kv_w(inputs_kv_shape=inputs_kv_shape)
-      self.value = self.init_kv_w(inputs_kv_shape=inputs_kv_shape)
+      if not self.share_kv_projections:
+        self.value = self.init_kv_w(inputs_kv_shape=inputs_kv_shape)
     self.out = self.init_out_w(output_dim=inputs_q_shape[-1])
 
   def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module:
@@ -1056,7 +1061,10 @@ def __call__(
     else:
       query = self.query_projection(inputs_q, out_sharding=qkv_sharding)
       key = self.kv_projection(inputs_kv, proj_name="key", out_sharding=qkv_sharding)
-      value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=qkv_sharding)
+      if self.share_kv_projections:
+        value = key
+      else:
+        value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=qkv_sharding)
 
     gate = None
     if self.is_qwen3_next:
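
Taken together, the hunks above implement the sharing at both weight-creation time (no self.value projection is built) and call time (the value tensor aliases the key tensor), which roughly halves the parameters spent on KV projections. A minimal JAX sketch of the call-time behavior, using hypothetical names and shapes rather than the Attention class above:

import jax
import jax.numpy as jnp

def project_kv(x, w_key, w_value=None, share_kv_projections=False):
  # x: [batch, seq, embed]; w_key / w_value: [embed, kv_heads, head_dim]
  key = jnp.einsum("bse,ekh->bskh", x, w_key)
  if share_kv_projections:
    value = key  # value reuses the key projection output; no value weight exists
  else:
    value = jnp.einsum("bse,ekh->bskh", x, w_value)
  return key, value

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (2, 8, 16))
w_key = jax.random.normal(k2, (16, 4, 32))
key, value = project_kv(x, w_key, share_kv_projections=True)
assert (key == value).all()  # K and V are identical tensors when sharing is enabled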

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ def __call__(
         reshape_q=cfg.reshape_q,
         use_mrope=cfg.use_mrope,
         mrope_section=cfg.mrope_section,
+        share_kv_projections=cfg.share_kv_projections,
         model_mode=model_mode,
       )
 
tests/unit/attention_test.py

Lines changed: 89 additions & 1 deletion
@@ -473,7 +473,15 @@ def test_tpu_kernel_attention_gqa(self):
   def test_tpu_kernel_attention_mqa(self):
     self.tpu_kernel_attention_helper(1)
 
-  def tpu_kernel_attention_helper(self, num_kv_heads):
+  @pytest.mark.tpu_only
+  def test_tpu_kernel_attention_mha_share_kv(self):
+    self.tpu_kernel_attention_helper(self.num_kv_heads, share_kv_projections=True)
+
+  @pytest.mark.tpu_only
+  def test_tpu_kernel_attention_gqa_share_kv(self):
+    self.tpu_kernel_attention_helper(self.num_kv_heads // 2, share_kv_projections=True)
+
+  def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False):
     """Test equivalence between dot_product and TPU accelerated"""
 
     lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype)
@@ -493,6 +501,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         attention_kernel="dot_product",
         dtype=self.dtype,
         dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=share_kv_projections,
         rngs=self.nnx_rng,
     )
 
@@ -522,6 +531,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         attention_kernel="flash",
         dtype=self.dtype,
         dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=share_kv_projections,
         rngs=self.nnx_rng,
     )
     nnx.update(attention_as_mha_flash, generic_state)
@@ -539,6 +549,84 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False)
     )
 
+  def test_share_kv_projections(self):
+    """Test that kv projections are shared."""
+    dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim))
+    dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim))
+    attention_share_kv = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.cfg.max_prefill_predict_length,
+        inputs_q_shape=dummy_inputs_q.shape,
+        inputs_kv_shape=dummy_inputs_kv.shape,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=True,
+        rngs=self.nnx_rng,
+    )
+
+    self.assertFalse(hasattr(attention_share_kv, "value"))
+    self.assertTrue(hasattr(attention_share_kv, "key"))
+
+    # 1. Check NNX state
+    state_shared = nnx.state(attention_share_kv)
+    self.assertNotIn("value", state_shared)
+    self.assertIn("key", state_shared)
+
+    # 2. Forward Pass Verification
+    lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype)
+
+    output_shared, _ = attention_share_kv(
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_TRAIN,
+    )
+
+    self.assertEqual(output_shared.shape, (self.global_batch_size, self.max_target_length, self.embed_dim))
+
+    # 3. Equivalence Check with standard unshared Attention
+    attention_no_share = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.cfg.max_prefill_predict_length,
+        inputs_q_shape=dummy_inputs_q.shape,
+        inputs_kv_shape=dummy_inputs_kv.shape,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        dropout_rate=self.cfg.dropout_rate,
+        share_kv_projections=False,
+        rngs=self.nnx_rng,
+    )
+
+    # Force unshared layer to copy weights from shared layer, mapping 'key' to 'value'
+    attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value
+    attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value
+    attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value
+    attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value
+
+    output_no_share, _ = attention_no_share(
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=MODEL_MODE_TRAIN,
+    )
+
+    self.assertTrue(jax.numpy.allclose(output_shared, output_no_share, rtol=1e-04, atol=1e-04, equal_nan=False))
+
   @parameterized.named_parameters(
       {
           "testcase_name": "cp_no_load_balance",
