
Commit 992fd41

gagika authored and Google-ML-Automation committed
PR #2968: Implement QK-Clip (Muon-Clip) for MLA attention
Imported from GitHub PR #2968

# Description

Implements QK-Clip, a training-stabilization technique for MLA attention models, as described in the Kimi K2 Technical Report.

**Changes:**

* **Core Logic:** Added `src/MaxText/utils/qk_clip_utils.py` containing `apply_qk_clip` and `calculate_max_logit_metric`.
* **Layers:** Updated `AttentionOp` to `sow` max-logit statistics and `AttentionMLA` to enable this when configured.
* **Training:** Integrated the clipping step and `max_logits` metric reporting into `src/MaxText/train.py`.
* **Tests:** Added `tests/qk_clip_test.py`.

**Context:**

QK-Clip mitigates training instability by preventing attention logits from growing excessively. This implementation:

1. Calculates the global max logit ($S_{max}$) using a GSPMD-compatible `jnp.max`.
2. Computes the per-head scaling factor $\gamma = \min(1, \tau / S_{max})$.
3. Scales $W_q$ and $W_k$ while explicitly leaving the shared rotary keys ($k^R$) and values ($W_v$) untouched.
4. Leverages Flax `sow` to pass statistics from the layers to the training loop efficiently.

# Tests

* **Unit Tests:** Ran `python3 tests/qk_clip_test.py`. Verified:
  * Correct scaling of $W_q$ and $W_k$.
  * Heads below the threshold are not clipped.
  * Shared keys and values remain untouched.
  * Global `max_logits` metric calculation.
  * Error handling for non-MLA attention types.
* **Integration:** Verified that `train_step` executes without shape mismatches or runtime errors.

# Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files).

Copybara import of the project:

--
0db9d22 by Gagik Amirkhanyan <agagik@google.com>:

Implement QK-Clip (Muon-Clip) functionality

add tests for QK-Clip logic
--

Merging this change closes #2968

COPYBARA_INTEGRATE_REVIEW=#2968 from AI-Hypercomputer:agagik-qk-clip 0db9d22
PiperOrigin-RevId: 874946094
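To make the clipping rule concrete, here is a minimal JAX sketch of steps 1–3 with toy numbers (the shapes and values are illustrative; the variable names follow `qk_clip_utils.py`):

```python
import jax.numpy as jnp

# Per-head max attention logits from the forward pass, shape [batch, heads].
max_logits = jnp.array([[120.0, 50.0],
                        [90.0, 30.0]])
tau = 100.0  # qk_clip_threshold

s_max = jnp.max(max_logits, axis=0)             # S_max per head: [120., 50.]
gamma = jnp.minimum(1.0, tau / (s_max + 1e-6))  # min(1, tau/S_max): [~0.83, 1.]

# Head 0 exceeded tau, so its query/key up-projections get rescaled; head 1
# is left alone. W_qc and W_kc each take sqrt(gamma) (so their product in the
# logit scales by gamma), W_qr takes the full gamma, and the shared rotary
# key k^R and the value weights W_v are never scaled.
print(s_max, gamma)
```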
1 parent fc865f4 commit 992fd41

7 files changed

Lines changed: 836 additions & 51 deletions


src/MaxText/utils/qk_clip_utils.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for QK-Clip (Muon Clip)."""

import jax
import jax.numpy as jnp


def _get_key_name(k):
  """Helper to unwrap JAX path keys."""
  if hasattr(k, "key"):
    return k.key
  if hasattr(k, "idx"):
    return k.idx
  return k


def calculate_max_logit_metric(intermediate_outputs):
  """Extracts and computes the global maximum logit from intermediate outputs.

  Args:
    intermediate_outputs: A pytree containing model intermediates, potentially
      including 'max_logits' sowed by Attention layers.

  Returns:
    The global maximum logit scalar, or None if no logits were found.
  """
  all_max_logits = []

  def extract_logits(path, val):
    # 'sow' stores values in a tuple/list. tree_map descends into it.
    # The path to the leaf array will look like: (..., 'max_logits', 0),
    # so we check whether the parent key (path[-2]) is 'max_logits'.
    if len(path) >= 2:
      parent_key = _get_key_name(path[-2])
      if parent_key == "max_logits":
        all_max_logits.append(val)

  jax.tree_util.tree_map_with_path(extract_logits, intermediate_outputs)

  if not all_max_logits:
    return None

  return jnp.max(jnp.stack(all_max_logits))


def apply_qk_clip(state, intermediate_outputs, config):
  """Applies QK-Clip to MLA weights based on max_logits.

  Iterates over parameters. If a parameter belongs to an MLA attention layer,
  it finds the corresponding max_logits statistics from intermediate_outputs,
  calculates the clipping factor, and applies it to W_q and W_k components.

  Args:
    state: The current training state containing model parameters.
    intermediate_outputs: A dictionary of intermediate outputs from the model
      forward pass. It is expected to contain 'max_logits' entries sowed by
      Attention layers if QK-Clip is enabled.
    config: The model configuration object, containing QK-Clip hyperparameters
      (e.g. qk_clip_threshold, qk_nope_head_dim) and attention_type.

  Returns:
    A new training state with updated (clipped) parameters.

  Raises:
    ValueError: If the configured attention_type is not 'mla'.
  """
  if getattr(config, "attention_type", None) != "mla":
    raise ValueError(
        f"QK-Clip is only supported for MLA attention (attention_type='mla'). "
        f"Current configuration: {getattr(config, 'attention_type', 'None')}"
    )

  tau = float(config.qk_clip_threshold)

  def clip_mla_weights(path, param):
    """Applies QK-Clip to a single parameter if it's an MLA projection weight.

    Args:
      path: A tuple of JAX Key objects representing the hierarchy path to the
        parameter in the state PyTree.
      param: The actual JAX array (weight tensor) at the given path.

    Returns:
      The scaled parameter if it is an MLA projection ('wq_b' or 'wkv_b'),
      otherwise the original parameter.
    """
    # Skip irrelevant weights (embeddings, norms, etc.).
    # We only care about specific MLA projection matrices ('wq_b', 'wkv_b').
    if len(path) < 2:
      return param

    layer_name = _get_key_name(path[-2])
    if layer_name not in ("wq_b", "wkv_b"):
      return param

    # Search for max_logits in intermediate_outputs.
    curr = intermediate_outputs.get("intermediates", intermediate_outputs)
    for node in path[:-2]:
      key = _get_key_name(node)
      if isinstance(curr, dict) and key in curr:
        curr = curr[key]
      else:
        return param  # Path not found in intermediates, skip.

    if not isinstance(curr, dict) or "max_logits" not in curr:
      return param

    # max_logits was sowed as a tuple (array,) with shape [batch, num_heads].
    max_logits_sowed = curr["max_logits"]
    if not max_logits_sowed:
      return param

    max_logits_batch = max_logits_sowed[0]

    # Calculate S_max per head: the global maximum across the batch dimension.
    # Result shape: [num_heads].
    s_max = jnp.max(max_logits_batch, axis=0)

    # Calculate the scaling factor gamma = tau / s_max; clip only if
    # s_max > tau.
    scale = jnp.minimum(1.0, tau / (s_max + 1e-6))

    # Apply QK clipping based on weight type.
    if layer_name == "wq_b":
      # MLA up-projection for the query: [rank, heads, q_head_dim].
      qk_nope = config.qk_nope_head_dim
      w_qc = param[..., :qk_nope]
      w_qr = param[..., qk_nope:]
      scale_b = scale[None, :, None]  # Broadcast: [1, heads, 1]
      w_qc_new = w_qc * jnp.sqrt(scale_b)
      w_qr_new = w_qr * scale_b
      return jnp.concatenate([w_qc_new, w_qr_new], axis=-1)

    elif layer_name == "wkv_b":
      # MLA up-projection for key/value: [rank, heads, kv_head_dim].
      qk_nope = config.qk_nope_head_dim
      w_kc = param[..., :qk_nope]
      w_v = param[..., qk_nope:]
      scale_b = scale[None, :, None]
      w_kc_new = w_kc * jnp.sqrt(scale_b)
      return jnp.concatenate([w_kc_new, w_v], axis=-1)

    return param

  # Apply the transformation to all parameters.
  new_params = jax.tree_util.tree_map_with_path(clip_mla_weights, state.params)
  return state.replace(params=new_params)
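For context, here is a small self-contained sketch of how the metric helper consumes sown intermediates. The nested layer names (`decoder`, `layers_0`, `attention`) are illustrative stand-ins for whatever paths the model actually sows under, and the import path assumes the repo layout above:

```python
import jax.numpy as jnp
from MaxText.utils import qk_clip_utils  # path per this PR's file layout

# Minimal stand-in for the pytree returned by
# model.apply(..., mutable=["intermediates"]): Flax `sow` stores each
# recorded value as a tuple under its name.
intermediate_outputs = {
    "intermediates": {
        "decoder": {
            "layers_0": {"attention": {"max_logits": (jnp.array([[120.0, 50.0]]),)}},
            "layers_1": {"attention": {"max_logits": (jnp.array([[90.0, 30.0]]),)}},
        }
    }
}

# The helper walks the pytree, collects every 'max_logits' leaf, and
# reduces them to a single scalar for metric reporting.
print(qk_clip_utils.calculate_max_logit_metric(intermediate_outputs))  # 120.0
```

`apply_qk_clip(state, intermediate_outputs, config)` then reuses the same pytree, matching each `wq_b`/`wkv_b` parameter against the `max_logits` sown at the corresponding layer path.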

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
@@ -355,6 +355,10 @@ qk_nope_head_dim: 128
 qk_rope_head_dim: 64
 v_head_dim: 128

+# QK-Clip (Muon Clip) Configuration
+use_qk_clip: False # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash)
+qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
+
 # Combine matmuls for QKV and MLP
 fused_qkv: False
 fused_mlp: False
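Assuming this PR keeps MaxText's usual `key=value` command-line overrides, the feature can be switched on per run without editing base.yml, e.g. by appending `use_qk_clip=True qk_clip_threshold=100.0` to the training invocation.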

src/maxtext/configs/types.py

Lines changed: 14 additions & 0 deletions
@@ -497,6 +497,8 @@ class Attention(BaseModel):
   use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
   use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")
   force_q_layout: bool = Field(False, description="Force the Q layout")
+  use_qk_clip: bool = Field(False, description="Whether to use QK-Clip (MuonClip) for training stability.")
+  qk_clip_threshold: float = Field(100.0, description="Threshold for QK-Clip (tau).")


 class MoBa(BaseModel):
@@ -2410,6 +2412,18 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.force_q_layout and not self.use_jax_splash:
       raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")

+    if self.use_qk_clip and self.attention_type != "mla":
+      raise ValueError(
+          f"QK-Clip is only supported when attention_type='mla', but found attention_type='{self.attention_type}'."
+      )
+
+    if self.use_qk_clip and self.attn_logits_soft_cap is not None:
+      raise ValueError(
+          "QK-Clip monitors raw dot products, but attn_logits_soft_cap is enabled. "
+          "Recording pre-cap max_logits is not fully supported yet. "
+          "Please disable attn_logits_soft_cap when using use_qk_clip."
+      )
+
     # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
     if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":

src/maxtext/layers/attention_mla.py

Lines changed: 13 additions & 2 deletions
@@ -1030,15 +1030,26 @@ def __call__(
         attention_mask=attention_mask,
     )

+    # Check if we need QK Clip stats
+    use_qk_clip = self.model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip
+
     if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
       unnormalized_out, _, exp_sum = self.ds_paged_attention_op(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       unnormalized_out = unnormalized_out[..., : self.v_head_dim]
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
     else:
-      # Pass the index_mask to the Attention Op
-      out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)
+      out = self.attention_op(
+          query,
+          key,
+          value,
+          decoder_segment_ids,
+          model_mode,
+          cached_values,
+          index_mask=index_mask,
+          record_max_logits=use_qk_clip,
+      )

     out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: