# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for QK-Clip (MuonClip)."""

import jax
import jax.numpy as jnp


def _get_key_name(k):
  """Unwraps a JAX pytree path key (e.g. DictKey, SequenceKey) to its raw value."""
  if hasattr(k, "key"):
    return k.key
  if hasattr(k, "idx"):
    return k.idx
  return k
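
# A minimal sketch of what _get_key_name unwraps (paths come from
# jax.tree_util; the tree here is illustrative):
#
#   leaves, _ = jax.tree_util.tree_flatten_with_path({"attn": [0.0]})
#   path, _ = leaves[0]  # (DictKey(key='attn'), SequenceKey(idx=0))
#   [_get_key_name(k) for k in path]  # -> ['attn', 0]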


def calculate_max_logit_metric(intermediate_outputs):
  """Extracts and computes the global maximum logit from intermediate outputs.

  Args:
    intermediate_outputs: A pytree containing model intermediates, potentially
      including 'max_logits' sowed by attention layers.

  Returns:
    The global maximum logit scalar, or None if no logits were found.
  """
  all_max_logits = []

  def extract_logits(path, val):
    # 'sow' stores values in a tuple/list, and tree_map descends into it, so
    # the path to a leaf array looks like (..., 'max_logits', 0). We therefore
    # check whether the parent key (path[-2]) is 'max_logits'.
    if len(path) >= 2:
      parent_key = _get_key_name(path[-2])
      if parent_key == "max_logits":
        all_max_logits.append(val)

  jax.tree_util.tree_map_with_path(extract_logits, intermediate_outputs)

  if not all_max_logits:
    return None

  # Reduce each leaf to a scalar before stacking so that leaves of different
  # shapes (e.g. from different layers) can still be combined.
  return jnp.max(jnp.stack([jnp.max(x) for x in all_max_logits]))
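
# Usage sketch (assumes a Flax model whose attention layers sow per-head max
# logits into the 'intermediates' collection; variable names are illustrative):
#
#   logits, intermediates = model.apply(
#       {"params": state.params}, batch, mutable=["intermediates"]
#   )
#   max_logit = calculate_max_logit_metric(intermediates)
#   if max_logit is not None:
#     metrics["max_attention_logit"] = max_logit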


def apply_qk_clip(state, intermediate_outputs, config):
  """Applies QK-Clip to MLA weights based on max_logits.

  Iterates over parameters. If a parameter belongs to an MLA attention layer,
  this finds the corresponding max_logits statistics in intermediate_outputs,
  computes the clipping factor, and applies it to the W_q and W_k components.

  Args:
    state: The current training state containing model parameters.
    intermediate_outputs: A dictionary of intermediate outputs from the model
      forward pass. It is expected to contain 'max_logits' entries sowed by
      attention layers if QK-Clip is enabled.
    config: The model configuration object, containing QK-Clip hyperparameters
      (e.g. qk_clip_threshold, qk_nope_head_dim) and attention_type.

  Returns:
    A new training state with updated (clipped) parameters.

  Raises:
    ValueError: If the configured attention_type is not 'mla'.
  """
  if getattr(config, "attention_type", None) != "mla":
    raise ValueError(
        "QK-Clip is only supported for MLA attention (attention_type='mla'). "
        f"Current configuration: {getattr(config, 'attention_type', None)}"
    )

  tau = float(config.qk_clip_threshold)

  def clip_mla_weights(path, param):
    """Applies QK-Clip to a single parameter if it is an MLA projection weight.

    Args:
      path: A tuple of JAX path keys giving the location of the parameter in
        the state pytree.
      param: The JAX array (weight tensor) at the given path.

    Returns:
      The scaled parameter if it is an MLA up-projection ('wq_b' or 'wkv_b'),
      otherwise the original parameter.
    """
    # Skip irrelevant weights (embeddings, norms, etc.); only the specific MLA
    # up-projection matrices ('wq_b', 'wkv_b') are clipped.
    if len(path) < 2:
      return param

    layer_name = _get_key_name(path[-2])
    if layer_name not in ("wq_b", "wkv_b"):
      return param
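
    # A matching parameter path looks roughly like (module names here are
    # illustrative, following typical Flax/MaxText nesting):
    #   ('decoder', 'layers_0', 'self_attention', 'wq_b', 'kernel')
    # so path[:-2] addresses the attention module that owns the weight.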

    # Search for max_logits in intermediate_outputs by walking the same path.
    curr = intermediate_outputs.get("intermediates", intermediate_outputs)
    for node in path[:-2]:
      key = _get_key_name(node)
      if isinstance(curr, dict) and key in curr:
        curr = curr[key]
      else:
        return param  # Path not found in intermediates; skip.

    if not isinstance(curr, dict) or "max_logits" not in curr:
      return param

    # max_logits was sowed as a tuple (array,) with shape [batch, num_heads].
    max_logits_sowed = curr["max_logits"]
    if not max_logits_sowed:
      return param

    max_logits_batch = max_logits_sowed[0]

    # S_max: the maximum logit per head, taken across the batch dimension.
    # Result shape: [num_heads]
    s_max = jnp.max(max_logits_batch, axis=0)

    # Scaling factor gamma = min(1, tau / S_max): heads with S_max <= tau are
    # left untouched; the epsilon guards against division by zero.
    scale = jnp.minimum(1.0, tau / (s_max + 1e-6))
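
    # Worked example with illustrative numbers: tau = 100 and
    # s_max = [50., 200.] give scale = [1.0, 0.5], so head 0 is untouched
    # while head 1's effective QK logits are halved by the scaling below.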

    # Apply QK-Clip based on weight type.
    if layer_name == "wq_b":
      # MLA up-projection for the query: [rank, heads, q_head_dim]. The
      # non-positional (nope) query part pairs with the clipped key part, so
      # each side takes sqrt(gamma); the rotary (rope) part pairs with the
      # shared, unclipped key rope, so the query side absorbs the full gamma.
      qk_nope = config.qk_nope_head_dim
      w_qc = param[..., :qk_nope]
      w_qr = param[..., qk_nope:]
      scale_b = scale[None, :, None]  # Broadcast: [1, heads, 1]
      w_qc_new = w_qc * jnp.sqrt(scale_b)
      w_qr_new = w_qr * scale_b
      return jnp.concatenate([w_qc_new, w_qr_new], axis=-1)

    elif layer_name == "wkv_b":
      # MLA up-projection for key/value: [rank, heads, kv_head_dim]. Only the
      # key (nope) part is clipped; the value part is left unchanged.
      qk_nope = config.qk_nope_head_dim
      w_kc = param[..., :qk_nope]
      w_v = param[..., qk_nope:]
      scale_b = scale[None, :, None]
      w_kc_new = w_kc * jnp.sqrt(scale_b)
      return jnp.concatenate([w_kc_new, w_v], axis=-1)

    return param

  # Apply the transformation to every parameter.
  new_params = jax.tree_util.tree_map_with_path(clip_mla_weights, state.params)
  return state.replace(params=new_params)
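

# Training-loop sketch (hypothetical names; the surrounding train step and
# config fields are assumptions, not part of this module):
#
#   state, metrics, intermediate_outputs = train_step(state, batch)
#   if config.attention_type == "mla" and config.qk_clip_threshold > 0:
#     state = apply_qk_clip(state, intermediate_outputs, config)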