
Commit 948d302

Merge pull request #2546 from AI-Hypercomputer:shuningjin-opt6

PiperOrigin-RevId: 847904636

2 parents: 9603b65 + a4650cf
12 files changed: 506 additions & 27 deletions

README.md
Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
 
 ## 🔥 Latest news 🔥
 
+* \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
 * \[December 10, 2025\] DeepSeek V3.1 is now supported. Use the existing config for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/models/deepseek3-671b.yml) and load a V3.1 checkpoint to use the model.
 * \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) are available.
 * \[December 4, 2025\] The [ReadTheDocs documentation site](https://maxtext.readthedocs.io/en/latest/index.html) has been reorganized.

src/MaxText/configs/base.yml
Lines changed: 9 additions & 1 deletion

@@ -704,7 +704,7 @@ gradient_clipping_threshold: 1.0
 # batch by accumulating the gradient over a set of steps.
 gradient_accumulation_steps: 1
 
-opt_type: "adamw" # one of "adamw", "adam_pax" or "sgd"
+opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"
 
 # AdamW optimizer parameters
 # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
@@ -717,6 +717,14 @@ mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inher
 # Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
 # See b/399961932 for more.
 
+# Muon optimizer parameters
+# https://github.com/google-deepmind/optax/blob/main/optax/contrib/_muon.py
+# "mu_dtype" and "adam_eps" are shared with AdamW.
+# "nesterov", "ns_coeffs", "ns_steps", "weight_decay_mask", and "adaptive" use their defaults.
+muon_beta: 0.95 # Decay rate for the exponentially weighted average of grads.
+muon_weight_decay: 0 # Strength of the weight decay regularization. This is multiplied with the learning rate.
+muon_consistent_rms: None # If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).
+
 # Stack trace parameters
 collect_stack_trace: False
 stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
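
With these keys in place, Muon is selected like any other optimizer, via config overrides. A hypothetical launch command, assuming the standard MaxText train entrypoint (the run name and model are illustrative; llama2 is in the tested set enforced by the types.py check below):

  python3 -m MaxText.train src/MaxText/configs/base.yml run_name=muon-test \
      model_name=llama2-7b opt_type=muon muon_consistent_rms=0.2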

src/MaxText/configs/types.py
Lines changed: 24 additions & 0 deletions

@@ -121,6 +121,7 @@ class OptimizerType(str, Enum):
   ADAMW = "adamw"
   ADAM_PAX = "adam_pax"
   SGD = "sgd"
+  MUON = "muon"
 
 
 class RopeType(str, Enum):
@@ -1041,6 +1042,18 @@ class AdamW(BaseModel):
   )
 
 
+class Muon(BaseModel):
+  """Configuration specific to the Muon optimizer."""
+
+  muon_beta: float = Field(0.95, description="Decay rate for the exponentially weighted average of grads.")
+  muon_weight_decay: float = Field(
+      0, description="Strength of the weight decay regularization. This is multiplied with the learning rate."
+  )
+  muon_consistent_rms: None | float = Field(
+      None, description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2)."
+  )
+
+
 class PositionalEmbedding(BaseModel):
   """General configuration for positional embeddings."""
 
@@ -1618,6 +1631,7 @@ class MaxTextConfig(
     TrainingLoop,
     Optimizer,
     AdamW,
+    Muon,
     FineTuning,
     # Reinforcement Learning
     RLHardware,
@@ -2119,6 +2133,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.use_grpo = True
     else:
       self.use_grpo = False
+    if self.opt_type == "muon" and self.decoder_block not in [
+        DecoderBlockType.DEEPSEEK,
+        DecoderBlockType.QWEN3,
+        DecoderBlockType.GEMMA3,
+        DecoderBlockType.LLAMA2,
+    ]:
+      raise ValueError(
+          "Muon dimension numbers haven't been tested for this model. Run this command first: "
+          f"`python3 -m MaxText.muon_utils {self.model_name} True`"
+      )
 
     # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
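
For any decoder block outside the tested set, the new check fails fast at config time and points at the inspection script. A hypothetical session (the model name is a placeholder, not a real config):

  python3 -m MaxText.train src/MaxText/configs/base.yml model_name=<untested-model> opt_type=muon
  # ValueError: Muon dimension numbers haven't been tested for this model. Run this command first:
  # `python3 -m MaxText.muon_utils <untested-model> True`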

src/MaxText/maxtext_utils.py
Lines changed: 13 additions & 12 deletions

@@ -752,18 +752,19 @@ def init_initial_state(model, tx, config, is_training, key):
 
 def get_abstract_param(model, config):
   """Get abstract model structure (name, shape) without materializing the weights to save memory"""
-  key = jax.random.PRNGKey(0)
-  input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
-  image_shape = multimodal_utils.get_dummy_image_shape_for_init(
-      config.model_name, batch_size=config.micro_batch_size_to_train_on
-  )
-  abstract_vars = jax.eval_shape(
-      model.init,
-      {"params": key, "dropout": key, "aqt": key},
-      jnp.ones(input_shape, dtype=jnp.int32),
-      jnp.ones(input_shape, dtype=jnp.int32),
-      encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
-  )
+  with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    key = jax.random.PRNGKey(0)
+    input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
+    image_shape = multimodal_utils.get_dummy_image_shape_for_init(
+        config.model_name, batch_size=config.micro_batch_size_to_train_on
+    )
+    abstract_vars = jax.eval_shape(
+        model.init,
+        {"params": key, "dropout": key, "aqt": key},
+        jnp.ones(input_shape, dtype=jnp.int32),
+        jnp.ones(input_shape, dtype=jnp.int32),
+        encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
+    )
   return abstract_vars
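The body is unchanged apart from the new mesh and axis-rules context, presumably so logical partitioning annotations resolve while tracing; Muon now calls this helper before any weights exist. The underlying trick is jax.eval_shape, which traces a function and returns only shapes and dtypes, never allocating buffers. A minimal self-contained sketch (the function and shapes are illustrative, not MaxText code):

  import jax
  import jax.numpy as jnp

  def make_params(x):
    # Would normally allocate a (features, 4) array; under eval_shape it never does.
    return {"w": jnp.zeros((x.shape[-1], 4))}

  abstract = jax.eval_shape(make_params, jnp.ones((2, 8), dtype=jnp.int32))
  print(abstract)  # {'w': ShapeDtypeStruct(shape=(8, 4), dtype=float32)}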
src/MaxText/muon_utils.py
Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Utilities for Muon optimizer integration and dimension number generation.
+
+This module provides functions to automatically generate MuonDimensionNumbers
+for various MaxText models. These dimension numbers are crucial for the Muon
+optimizer to correctly apply its update rules.
+
+This module can also be run as a script to inspect the generated dimension
+numbers for a specific model. Example:
+  python3 -m MaxText.muon_utils qwen3-4b True
+"""
+
+
+import os
+import sys
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+from optax.contrib._muon import MuonDimensionNumbers as mdn
+
+from MaxText import maxtext_utils, pyconfig
+from MaxText.globals import MAXTEXT_PKG_DIR
+from MaxText.layers import models, quantizations
+
+
+Transformer = models.transformer_as_linen
+
+
+def _is_path_contain_any(tuples, path):
+  return any(x in path for x in tuples)
+
+
+def transform_logic(path: Tuple[str, ...]) -> Optional[mdn]:
+  """
+  Determines Muon dimension numbers based on the parameter's hierarchical path.
+
+  This function defines the mapping from a parameter's logical path within the model
+  to its corresponding MuonDimensionNumbers (mdn). The strategy is applied in
+  a specific order to handle general cases and then more specific ones, allowing
+  for fall-through logic in nested structures.
+
+  Strategy:
+  1. Exclusions: Parameters not suitable for Muon (e.g., scalars, embeddings,
+     unembedding) are explicitly returned as `None`.
+  2. Special weights:
+     2.1 MoE-block-specific weights
+     2.2 Self-attention-specific weights
+  3. Standard weights: Default mapping for most other 3D weight shapes.
+
+  Args:
+    path: A tuple of strings representing the hierarchical path of the parameter.
+
+  Returns:
+    An instance of `MuonDimensionNumbers` if a specific mapping is found,
+    `None` for excluded parameters, or a default `mdn` for standard weights.
+  """
+
+  # 1 Exclude parameters not suitable for Muon (scalars, embeddings, unembedding)
+  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense"), path):
+    return None
+
+  # 2 Special weights
+  # 2.1 Special weights: MoE, [0, L, -2, -1]
+  # L (optional) stands for layer when scan_layers=True
+  if "MoeBlock_0" in path:
+    # exclude gate
+    if _is_path_contain_any(("wi_0", "wi_1", "wo"), path):
+      return mdn((-2,), (-1,))
+
+  # 2.2 Special weights: self-attention
+  elif "self_attention" in path:
+    # Attention output projection: [0, L, -2, -1]
+    if "out" in path:
+      return mdn((0, -2), (-1,))
+    # Attention qkv projection: [0, L, -2, -1]
+    # MLA: exclude wq_a / wkv_a
+    elif _is_path_contain_any(("query", "key", "value", "wq_b", "wkv_b"), path):
+      return mdn((0,), (-2, -1))
+
+  # 3 Standard weights, [0, L, -1]
+  return mdn((0,), (-1,))
+
+
+def get_transform_tree(tree, path=()):
+  """Extraction utility via recursion."""
+  if isinstance(tree, dict):
+    return {k: get_transform_tree(v, path + (k,)) for k, v in tree.items()}
+  else:
+    return transform_logic(path)
+
+
+def get_muon_weight_dimension_numbers(model, config, verbose=False):
+  """Extracts Muon dimension numbers from the model structure."""
+  # quickly get the param structure without materialization
+  abstract_param = maxtext_utils.get_abstract_param(model, config)
+  # get Muon dimension numbers from the params
+  muon_weight_dimension_numbers = get_transform_tree(abstract_param)
+  if verbose:
+    _print_structure_debug(abstract_param, muon_weight_dimension_numbers)
+  return muon_weight_dimension_numbers
+
+
+def _print_structure_debug(abstract_param, muon_weight_dimension_numbers):
+  """Prints the model structure and the resulting Muon config."""
+  # Access the shape from the inner ShapeDtypeStruct and names from the wrapper.
+  # Return a new tree with the same structure containing only shapes/names.
+  info_tree = jax.tree_util.tree_map(
+      lambda leaf: {"shape": leaf.value.shape, "names": leaf.names},
+      abstract_param,
+      is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
+  )
+  print(f"\n=== Model Structure ===\n{info_tree}")
+  print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}")
+  print("\nIs this reasonable?")
+
+
+def get_model_mdn(model_name, scan_layers=True, verbose=False):
+  """Initializes a model and retrieves its Muon dimension numbers.
+
+  This function sets up the configuration for a given model, initializes the
+  transformer model, and then extracts the Muon dimension numbers for the model's
+  weights. It can optionally print verbose debug information.
+
+  Args:
+    model_name: The name of the model to be initialized.
+    scan_layers: Whether to use layer scanning in the model configuration.
+    verbose: If True, prints detailed debugging information about the model
+      structure and Muon dimension numbers.
+
+  Returns:
+    A tree structure containing the Muon dimension numbers for the model's
+    parameters.
+  """
+  # Setup config
+  argv = [
+      None,
+      os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+      f"model_name={model_name}",
+      f"scan_layers={scan_layers}",
+      "attention=dot_product",
+  ]
+  config = pyconfig.initialize(argv)
+  # Setup model
+  devices_array = maxtext_utils.create_device_mesh(config)
+  mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
+  quant = quantizations.configure_quantization(config)
+  model = Transformer(config, mesh=mesh, quant=quant)
+  # Get dimension numbers
+  muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose)
+  return muon_weight_dimension_numbers
+
+
+if __name__ == "__main__":
+  if len(sys.argv) != 3:
+    print("Usage: python3 -m MaxText.muon_utils <model_name> <scan_layers:True/False>")
+    sys.exit(1)
+  model_name_arg = sys.argv[1]
+  scan_layers_arg = sys.argv[2].lower() == "true"
+  get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
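
transform_logic can be probed directly to sanity-check the mapping; a minimal sketch with hand-written paths (the paths are illustrative, not taken from a real checkpoint):

  from MaxText.muon_utils import transform_logic

  # Standard weight: reduce over dim 0, output on the last dim.
  print(transform_logic(("decoder", "layers", "mlp", "wi_0", "kernel")))  # mdn((0,), (-1,))
  # Attention qkv projection: output spans the trailing (heads, head_dim) pair.
  print(transform_logic(("decoder", "layers", "self_attention", "query", "kernel")))  # mdn((0,), (-2, -1))
  # Embeddings are excluded; params mapped to None are handled by the optimizer's AdamW side
  # (hence the adam_* kwargs passed in optimizers.py below).
  print(transform_logic(("token_embedder", "embedding")))  # None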

src/MaxText/optimizers.py
Lines changed: 26 additions & 1 deletion

@@ -19,9 +19,11 @@
 import jax.numpy as jnp
 
 import optax
+from optax.contrib._muon import muon
+from MaxText.muon_utils import get_muon_weight_dimension_numbers
 
 
-def get_optimizer(config, learning_rate_schedule):
+def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
     # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
@@ -45,6 +47,29 @@ def get_optimizer(config, learning_rate_schedule):
     )
   elif config.opt_type == "sgd":
     return optax.sgd(learning_rate_schedule)
+  elif config.opt_type == "muon":
+    # Extract Muon dimension numbers from the model structure.
+    if model is not None:
+      muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config)
+    else:
+      raise ValueError("Please specify a model to extract Muon dimension numbers.")
+    muon_kwargs = {
+        # Shared parameters: "nesterov" uses its default.
+        "learning_rate": learning_rate_schedule,
+        "eps": config.adam_eps,
+        "mu_dtype": config.mu_dtype,
+        # Muon-specific parameters: "ns_coeffs", "ns_steps", "weight_decay_mask", "adaptive" use their defaults.
+        "beta": config.muon_beta,
+        "weight_decay": config.muon_weight_decay,
+        "muon_weight_dimension_numbers": muon_weight_dimension_numbers,
+        "consistent_rms": config.muon_consistent_rms,
+        # AdamW-specific parameters
+        "adam_b1": config.adam_b1,
+        "adam_b2": config.adam_b2,
+        "adam_eps_root": config.adam_eps_root,
+        "adam_weight_decay": config.adam_weight_decay,
+    }
+    return muon(**muon_kwargs)
   else:
     raise ValueError(f"{config.opt_type=} is not supported.")
 
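Since the Muon branch derives dimension numbers from the parameter tree, every call site now threads the model through. A minimal sketch of the new calling convention (config, schedule, and model construction elided):

  from MaxText import optimizers

  # model defaults to None and may stay None for adamw/adam_pax/sgd;
  # for opt_type="muon" it is required, otherwise a ValueError is raised.
  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)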
src/MaxText/pyconfig.py
Lines changed: 12 additions & 7 deletions

@@ -132,6 +132,16 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
     if key == "run_name" and new_value is None:
       new_value = ""
 
+    # Preprocess muon_consistent_rms to be None or float
+    if key == "muon_consistent_rms":
+      if value in ["None", "none"]:
+        new_value = None
+      else:
+        try:
+          new_value = float(value)
+        except ValueError as e:
+          raise ValueError("muon_consistent_rms should be None or float") from e
+
     pydantic_kwargs[key] = new_value
 
   return pydantic_kwargs
@@ -293,13 +303,8 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
 
   pydantic_kwargs = _prepare_for_pydantic(raw_keys_dict)
 
-  if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get(
-      "use_jax_splash"
-  ):
-    raise ValueError(
-        "At most one of `use_tokamax_splash` and `use_jax_splash` can be set to"
-        " True."
-    )
+  if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get("use_jax_splash"):
+    raise ValueError("At most one of `use_tokamax_splash` and `use_jax_splash` can be set to True.")
 
   # Initialize JAX distributed system before device backend is initialized.
   if pydantic_kwargs.get("jax_debug_log_modules"):

src/MaxText/sft/sft_trainer.py
Lines changed: 2 additions & 1 deletion

@@ -144,7 +144,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
     model, mesh = model_creation_utils.create_nnx_model(mt_config)
     learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
-    optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule)
+    # pass in model for muon
+    optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
 
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
     training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)

src/MaxText/train_compile.py
Lines changed: 2 additions & 1 deletion

@@ -92,7 +92,8 @@ def get_shaped_inputs(topology_mesh, config):
   model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
   # The learning_rate_schedule is baked into the compiled object.
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
-  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  # pass in model for muon
+  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
 
   # Shaped RNG keys
   _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)

src/MaxText/train_utils.py
Lines changed: 2 additions & 1 deletion

@@ -36,7 +36,8 @@ def create_training_tools(config, model, mesh):
   """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager."""
   init_rng = jax.random.PRNGKey(config.init_weights_seed)
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
-  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  # pass in model for muon
+  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
   logger = checkpointing.setup_checkpoint_logger(config)
   if config.enable_multi_tier_checkpointing:
     checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
