
Commit e091675

Merge pull request #3280 from AI-Hypercomputer:weight_decay
PiperOrigin-RevId: 877653152
2 parents 9f98518 + daaa03c

4 files changed: 172 additions & 11 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -779,6 +779,7 @@ adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_eps_root: 0. # A small constant applied to denominator inside the square root.
 adam_weight_decay: 0.1 # AdamW Weight decay
+adamw_mask: [] # List of parameter names/patterns to exclude from weight decay in AdamW, like ['bias', '.*norm', '.*ln.*'].
 mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inherits from weight_dtype if unset.
 # Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
 # See b/399961932 for more.
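
Usage note: the new flag is exercised the same way the tests in this commit do it, through a pyconfig override. A minimal sketch (the run name and override values below are illustrative, not part of the commit):

from maxtext.configs import pyconfig
from maxtext.optimizers import optimizers

# Exclude biases and any parameter whose path matches '.*norm' from weight decay.
argv = ["", "src/maxtext/configs/base.yml", "run_name=demo", "adamw_mask=['bias', '.*norm']"]
config = pyconfig.initialize(argv)
mask_fn = optimizers.get_adamw_mask(config)  # returns None when adamw_mask is empty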

src/maxtext/configs/types.py

Lines changed: 6 additions & 0 deletions
@@ -1175,6 +1175,12 @@ class AdamW(BaseModel):
       description="A small constant for numerical stability (epsilon), applied inside of the square root.",
   )
   adam_weight_decay: float = Field(0.1, description="Weight decay regularization.")
+  adamw_mask: list[str] = Field(
+      default_factory=list,
+      description=(
+          "List of parameter names/patterns to exclude from weight decay in AdamW, like ['bias', '.*norm', '.*ln.*']"
+      ),
+  )
   mu_dtype: str = Field(
       "",
       description="Data type for 'mu' (first moment) in AdamW. Inherits from weight_dtype if empty.",

src/maxtext/optimizers/optimizers.py

Lines changed: 29 additions & 1 deletion
@@ -15,6 +15,7 @@
 # pylint: disable=bare-except, consider-using-generator, too-many-positional-arguments
 """ Utils that are only interesting to MaxText. """
 
+import re
 import jax
 import jax.numpy as jnp
 
@@ -23,6 +24,26 @@
 from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers
 
 
+def get_adamw_mask(config):
+  """Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
+  if not getattr(config, "adamw_mask", None):
+    return None
+
+  compiled_patterns = [re.compile(pattern) for pattern in config.adamw_mask]
+
+  def mask_fn(params):
+    def _is_decayed(path, _):
+      # Join path keys into a single string for pattern matching (e.g., "layer1/bias").
+      path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
+      # If any pattern in adamw_mask matches the path, exclude from weight decay (return False).
+      # Otherwise, apply weight decay (return True).
+      return not any(pattern.search(path_str) for pattern in compiled_patterns)
+
+    return jax.tree_util.tree_map_with_path(_is_decayed, params)
+
+  return mask_fn
+
+
 def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
@@ -35,6 +56,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         eps_root=config.adam_eps_root,
         weight_decay=config.adam_weight_decay,
         mu_dtype=config.mu_dtype,
+        mask=get_adamw_mask(config),
     )
   elif config.opt_type == "adam_pax":
     return adam_pax(
@@ -44,6 +66,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         epsilon=config.adam_eps,
         epsilon_root=config.adam_eps_root,
         weight_decay=config.adam_weight_decay,
+        mask=get_adamw_mask(config),
    )
  elif config.opt_type == "sgd":
    return optax.sgd(learning_rate_schedule)
@@ -81,6 +104,7 @@ def adam_pax(
     epsilon: float,
     epsilon_root: float,
     weight_decay: float,
+    mask=None,
 ) -> optax.GradientTransformation:
   """Standard Adam optimizer that supports weight decay.
 
@@ -162,7 +186,11 @@ def _update_momentum(update, mu, nu):
     updates = jax.tree_util.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)
 
     if weight_decay > 0:
-      updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)
+      if mask is not None:
+        mask_tree = mask(params) if callable(mask) else mask
+        updates = jax.tree_util.tree_map(lambda x, v, m: x + weight_decay * v if m else x, updates, params, mask_tree)
+      else:
+        updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)
 
     step_size = -1.0 * learning_rate_fn(count)
     # Finally, fold in step size.
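
To make the mask semantics concrete, here is a self-contained sketch of the path matching that get_adamw_mask performs (the toy parameter tree and patterns are illustrative): True marks a leaf that receives weight decay, False marks one that is skipped.

import re
import jax

patterns = [re.compile(p) for p in ["bias", ".*norm"]]

def is_decayed(path, _):
  # Mirror get_adamw_mask: join pytree keys into "layer1/bias"-style path strings.
  path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
  return not any(pat.search(path_str) for pat in patterns)

params = {"layer1": {"kernel": 1.0, "bias": 0.0}, "layer_norm": {"scale": 1.0}}
print(jax.tree_util.tree_map_with_path(is_decayed, params))
# {'layer1': {'bias': False, 'kernel': True}, 'layer_norm': {'scale': False}}

Note that optax.adamw already accepts such a mask (a boolean pytree, or a callable that produces one from the params), which is why get_optimizer can pass get_adamw_mask(config) straight through; adam_pax gains the equivalent behavior via its new mask parameter and the masked tree_map shown in the hunk above.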
Lines changed: 136 additions & 10 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,19 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-"""Unit tests for Muon dimension number generation.
-
-This suite verifies that the automatically generated Muon dimension numbers
-for various models match their hardcoded reference values.
-python3 -m pytest -v --pyargs tests.muon_test -rP -s
-"""
-
+""" Unit tests for all optimizers. """
+import re
 import unittest
+from unittest.mock import patch
+import jax
+
+import pytest
 from absl.testing import parameterized
 from optax.contrib import MuonDimensionNumbers as mdn
+
+from maxtext.configs import pyconfig
+from maxtext.optimizers import optimizers
+from maxtext.utils import maxtext_utils
 from maxtext.utils.muon_utils import get_model_mdn
-import pytest
+from tests.utils.test_helpers import get_test_config_path
+from typing import NamedTuple
+
 
 # deepseek2, specific: q_lora_rank=0
 # applicable: deepseek2-16, but not deepseek2-236b (q_lora_rank=1536)
@@ -214,6 +218,11 @@
 
 
 class MuonDimensionTest(parameterized.TestCase):
+  """Unit tests for Muon dimension number generation.
+
+  This suite verifies that the automatically generated Muon dimension numbers
+  for various models match their hardcoded reference values.
+  """
 
   @parameterized.named_parameters(
       ("deepseek2-16b", "deepseek2-16b", DEEPSEEK2_DIMENSION_NUMBER),
@@ -236,5 +245,122 @@ def test_model_integration(self, model_name, expected_output):
     self.assertEqual(actual_output, expected_output)
 
 
+class AdamWMaskTest(parameterized.TestCase):
+  """Tests for the AdamW mask functionality."""
+
+  def test_get_adamw_mask_with_empty_mask(self):
+    """Directly test the get_adamw_mask function with an empty list."""
+    # Case 1: No mask in config (empty list)
+    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=[]"]
+    config = pyconfig.initialize(argv)
+    mask_fn = optimizers.get_adamw_mask(config)
+    self.assertIsNone(mask_fn)
+
+  def test_get_adamw_mask_with_valid_mask(self):
+    """Directly test the get_adamw_mask function with a valid mask."""
+    # Case 2: Mask in config
+    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=['bias', '.*norm', '.*ln.*']"]
+    config = pyconfig.initialize(argv)
+    mask_fn = optimizers.get_adamw_mask(config)
+    self.assertTrue(callable(mask_fn))
+
+    params = {"layer1": {"kernel": 1, "bias": 2}, "layer2": {"layer_norm": {"scale": 3}}, "layer3": {"ln": {"scale": 4}}}
+    mask = mask_fn(params)
+    self.assertTrue(mask["layer1"]["kernel"])
+    self.assertFalse(mask["layer1"]["bias"])
+    self.assertFalse(mask["layer2"]["layer_norm"]["scale"])
+    self.assertFalse(mask["layer3"]["ln"]["scale"])
+
+  def test_get_adamw_mask_with_invalid_mask(self):
+    """Test that an invalid regex in the mask config raises an error when used."""
+    # Create a config with an invalid regex (unbalanced bracket)
+    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=['[']"]
+    config = pyconfig.initialize(argv)
+
+    # Building the mask should raise re.error when the invalid regex is compiled.
+    with self.assertRaises(re.error):
+      optimizers.get_adamw_mask(config)
+
+  def test_get_adamw_mask_with_getattrkey(self):
+    """Test that get_adamw_mask correctly handles GetAttrKey (e.g. from NamedTuples)."""
+
+    class MyParams(NamedTuple):
+      kernel: jax.Array
+      bias: jax.Array
+
+    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=['bias']"]
+    config = pyconfig.initialize(argv)
+    mask_fn = optimizers.get_adamw_mask(config)
+
+    params = MyParams(kernel=jax.numpy.ones((2, 2)), bias=jax.numpy.zeros((2,)))
+    mask = mask_fn(params)
+
+    self.assertTrue(mask.kernel)
+    self.assertFalse(mask.bias)
+
+  @parameterized.named_parameters(
+      ("adamw", "adamw", "maxtext.optimizers.optimizers.optax.adamw"),
+      ("adam_pax", "adam_pax", "maxtext.optimizers.optimizers.adam_pax"),
+  )
+  def test_optimizer_with_mask(self, opt_type, mock_path):
+    """Test that the optimizer receives the mask function from config and that it works as expected."""
+    # Create a config with a mask list that includes regex patterns
+    argv = [
+        "",
+        get_test_config_path(),
+        "run_name=test",
+        "adamw_mask=['bias', 'layer_norm', 'layer1/.*kernel']",
+        f"opt_type={opt_type}",
+    ]
+    config = pyconfig.initialize(argv)
+    learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
+
+    with patch(mock_path) as mock_opt:
+      # Call get_optimizer
+      optimizers.get_optimizer(config, learning_rate_schedule)
+
+      # Check that the optimizer was called with a mask function
+      mock_opt.assert_called_once()
+      _, kwargs = mock_opt.call_args
+      mask_fn = kwargs["mask"]
+
+      # Verify that mask_fn is not None
+      self.assertIsNotNone(mask_fn)
+
+      # Test the behavior of mask_fn
+      params = {"layer1": {"kernel": 1, "bias": 2}, "layer2": {"layer_norm": {"scale": 3}}, "layer3": [4, 5]}
+
+      mask = mask_fn(params)
+
+      # kernel in layer1 should be False because of 'layer1/.*kernel'
+      self.assertFalse(mask["layer1"]["kernel"])
+      # bias in layer1 should be False because of 'bias'
+      self.assertFalse(mask["layer1"]["bias"])
+      # layer_norm should be False because of 'layer_norm'
+      self.assertFalse(mask["layer2"]["layer_norm"]["scale"])
+      # layer3 elements should be True
+      self.assertTrue(mask["layer3"][0])
+      self.assertTrue(mask["layer3"][1])
+
+  @parameterized.named_parameters(
+      ("adamw", "adamw", "maxtext.optimizers.optimizers.optax.adamw"),
+      ("adam_pax", "adam_pax", "maxtext.optimizers.optimizers.adam_pax"),
+  )
+  def test_optimizer_without_mask(self, opt_type, mock_path):
+    """Test that the optimizer receives None for mask when adamw_mask is not set."""
+    argv = ["", get_test_config_path(), "run_name=test", f"opt_type={opt_type}"]
+    config = pyconfig.initialize(argv)
+    learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
+
+    with patch(mock_path) as mock_opt:
+      # Call get_optimizer
+      optimizers.get_optimizer(config, learning_rate_schedule)
+
+      # Check that the optimizer was called with mask=None
+      mock_opt.assert_called_once()
+      _, kwargs = mock_opt.call_args
+      self.assertIsNone(kwargs["mask"])
+
+
 if __name__ == "__main__":
   unittest.main()
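
To run the expanded suite locally, the command from the removed module docstring should still apply, assuming the test module keeps its tests.muon_test path (this diff does not show a rename):

python3 -m pytest -v --pyargs tests.muon_test -rP -s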
