|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +"""Utilities for Muon optimizer integration and dimension number generation. |
| 17 | +
|
| 18 | +This module provides functions to automatically generate MuonDimensionNumbers |
| 19 | +for various MaxText models. These dimension numbers are crucial for the Muon |
| 20 | +optimizer to correctly apply its update rules. |
| 21 | +
|
| 22 | +This module can also be run as a script to inspect the generated dimension |
| 23 | +numbers for a specific model. Example: |
| 24 | + python3 -m MaxText.muon_utils qwen3-4b True |
| 25 | +""" |
| 26 | + |
| 27 | + |
| 28 | +import os |
| 29 | +import sys |
| 30 | +from typing import Optional, Tuple |
| 31 | + |
| 32 | +import flax.linen as nn |
| 33 | +import jax |
| 34 | +from optax.contrib._muon import MuonDimensionNumbers as mdn |
| 35 | + |
| 36 | +from MaxText import maxtext_utils, pyconfig |
| 37 | +from MaxText.globals import MAXTEXT_PKG_DIR |
| 38 | +from MaxText.layers import models, quantizations |
| 39 | + |
| 40 | + |
| 41 | +Transformer = models.transformer_as_linen |
| 42 | + |
| 43 | + |
| 44 | +def _is_path_contain_any(tuples, path): |
| 45 | + return any(x in path for x in tuples) |
| 46 | + |
| 47 | + |
| 48 | +def transform_logic(path: Tuple[str, ...]) -> Optional[mdn]: |
| 49 | + """ |
| 50 | + Determines Muon dimension numbers based on the parameter's hierarchical path. |
| 51 | +
|
| 52 | + This function defines the mapping from a parameter's logical path within the model |
| 53 | + to its corresponding MuonDimensionNumbers (mdn). The strategy is applied in |
| 54 | + a specific order to handle general cases and then more specific ones, allowing |
| 55 | + for fall-through logic in nested structures. |
| 56 | +
|
| 57 | + Strategy: |
| 58 | + 1. Exclusions: Parameters not suitable for Muon (e.g., scalars, embeddings, |
| 59 | + unembedding) are explicitly returned as `None`. |
| 60 | + 2. Special Weights: |
| 61 | + 2.1 MoE Block Specific Weights |
| 62 | + 2.2 Self-Attention Specific Weights |
| 63 | + 3. Standard Weights: Default mapping for most other 3D weight shapes. |
| 64 | +
|
| 65 | + Args: |
| 66 | + path: A tuple of strings representing the hierarchical path of the parameter. |
| 67 | +
|
| 68 | + Returns: |
| 69 | + An instance of `MuonDimensionNumbers` if a specific mapping is found, |
| 70 | + `None` for excluded parameters, or a default `mdn` for standard weights. |
| 71 | + """ |
| 72 | + |
| 73 | + # 1 Exclude parameters not suitable for Muon (scalar, embeddings, unembedding) |
| 74 | + if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense"), path): |
| 75 | + return None |
| 76 | + |
| 77 | + # 2 Special weights |
| 78 | + # 2.1 Special weights: MoE, [0, L, -2, -1] |
| 79 | + # L (optional) stands for layer when scan_layers=True |
| 80 | + if "MoeBlock_0" in path: |
| 81 | + # exclude gate |
| 82 | + if _is_path_contain_any(("wi_0", "wi_1", "wo"), path): |
| 83 | + return mdn((-2,), (-1,)) |
| 84 | + |
| 85 | + # 2.2 Special weights: Self attention |
| 86 | + elif "self_attention" in path: |
| 87 | + # Attention output projection: [0, L, -2, -1] |
| 88 | + if "out" in path: |
| 89 | + return mdn((0, -2), (-1,)) |
| 90 | + # Attention qkv projection: [0, L, -2, -1] |
| 91 | + # MLA, exclude wq_a / wkv_a |
| 92 | + elif _is_path_contain_any(("query", "key", "value", "wq_b", "wkv_b"), path): |
| 93 | + return mdn((0,), (-2, -1)) |
| 94 | + |
| 95 | + # 3 Standard weights, [0, L, -1] |
| 96 | + return mdn((0,), (-1,)) |
| 97 | + |
| 98 | + |
| 99 | +def get_transform_tree(tree, path=()): |
| 100 | + """Extraction utility via recursion.""" |
| 101 | + if isinstance(tree, dict): |
| 102 | + return {k: get_transform_tree(v, path + (k,)) for k, v in tree.items()} |
| 103 | + else: |
| 104 | + return transform_logic(path) |
| 105 | + |
| 106 | + |
| 107 | +def get_muon_weight_dimension_numbers(model, config, verbose=False): |
| 108 | + """Extract muon dimension number from model structure.""" |
| 109 | + # quickly get param structure without materialization |
| 110 | + abstract_param = maxtext_utils.get_abstract_param(model, config) |
| 111 | + # get muon dimension number from param |
| 112 | + muon_weight_dimension_numbers = get_transform_tree(abstract_param) |
| 113 | + if verbose: |
| 114 | + _print_structure_debug(abstract_param, muon_weight_dimension_numbers) |
| 115 | + return muon_weight_dimension_numbers |
| 116 | + |
| 117 | + |
| 118 | +def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): |
| 119 | + """Prints the model structure and the resulting Muon config.""" |
| 120 | + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper |
| 121 | + # Return a new tree with the same structure containing only shapes/names |
| 122 | + info_tree = jax.tree_util.tree_map( |
| 123 | + lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, |
| 124 | + abstract_param, |
| 125 | + is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), |
| 126 | + ) |
| 127 | + print(f"\n=== Model Structure ===\n{info_tree}") |
| 128 | + print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") |
| 129 | + print("\nIs this reasonable?") |
| 130 | + |
| 131 | + |
| 132 | +def get_model_mdn(model_name, scan_layers=True, verbose=False): |
| 133 | + """Initializes a model and retrieves its Muon dimension numbers. |
| 134 | +
|
| 135 | + This function sets up the configuration for a given model, initializes the |
| 136 | + transformer model, and then extracts the Muon dimension numbers for the model's |
| 137 | + weights. It can optionally print verbose debug information. |
| 138 | +
|
| 139 | + Args: |
| 140 | + model_name: The name of the model to be initialized. |
| 141 | + scan_layers: Whether to use layer scanning in the model configuration. |
| 142 | + verbose: If True, prints detailed debugging information about the model |
| 143 | + structure and Muon dimension numbers. |
| 144 | +
|
| 145 | + Returns: |
| 146 | + A tree structure containing the Muon dimension numbers for the model's |
| 147 | + parameters. |
| 148 | + """ |
| 149 | + # Setup config |
| 150 | + argv = [ |
| 151 | + None, |
| 152 | + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), |
| 153 | + f"model_name={model_name}", |
| 154 | + f"scan_layers={scan_layers}", |
| 155 | + "attention=dot_product", |
| 156 | + ] |
| 157 | + config = pyconfig.initialize(argv) |
| 158 | + # Setup model |
| 159 | + devices_array = maxtext_utils.create_device_mesh(config) |
| 160 | + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) |
| 161 | + quant = quantizations.configure_quantization(config) |
| 162 | + model = Transformer(config, mesh=mesh, quant=quant) |
| 163 | + # Get dimension number |
| 164 | + muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) |
| 165 | + return muon_weight_dimension_numbers |
| 166 | + |
| 167 | + |
| 168 | +if __name__ == "__main__": |
| 169 | + if len(sys.argv) != 3: |
| 170 | + print("Usage: python3 -m MaxText.muon_utils <model_name> <scan_layers:True/False>") |
| 171 | + sys.exit(1) |
| 172 | + model_name_arg = sys.argv[1] |
| 173 | + scan_layers_arg = sys.argv[2].lower() == "true" |
| 174 | + get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) |
0 commit comments