1- # Copyright 2023–2025 Google LLC
1+ # Copyright 2023–2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15-
16- """Unit tests for Muon dimension number generation.
17-
18- This suite verifies that the automatically generated Muon dimension numbers
19- for various models match their hardcoded reference values.
20- python3 -m pytest -v --pyargs tests.muon_test -rP -s
21- """
22-
15+ """ Unit tests for all optimizers. """
16+ import re
2317import unittest
18+ from unittest .mock import patch
19+ import jax
20+
21+ import pytest
2422from absl .testing import parameterized
2523from optax .contrib import MuonDimensionNumbers as mdn
24+
25+ from maxtext .configs import pyconfig
26+ from maxtext .optimizers import optimizers
27+ from maxtext .utils import maxtext_utils
2628from maxtext .utils .muon_utils import get_model_mdn
27- import pytest
29+ from tests .utils .test_helpers import get_test_config_path
30+ from typing import NamedTuple
31+
2832
2933# deepseek2, specific: q_lora_rank=0
3034# applicable: deepseek2-16, but not deepseek2-236b (q_lora_rank=1536)
214218
215219
216220class MuonDimensionTest (parameterized .TestCase ):
221+ """Unit tests for Muon dimension number generation.
222+
223+ This suite verifies that the automatically generated Muon dimension numbers
224+ for various models match their hardcoded reference values.
225+ """
217226
218227 @parameterized .named_parameters (
219228 ("deepseek2-16b" , "deepseek2-16b" , DEEPSEEK2_DIMENSION_NUMBER ),
@@ -236,5 +245,122 @@ def test_model_integration(self, model_name, expected_output):
236245 self .assertEqual (actual_output , expected_output )
237246
238247
class AdamWMaskTest(parameterized.TestCase):
  """Tests for the AdamW weight-decay mask functionality.

  `optimizers.get_adamw_mask(config)` turns the `adamw_mask` list of regex
  patterns from the config into a pytree mask function: parameters whose
  path matches any pattern map to False (excluded from weight decay), all
  other leaves map to True.
  """

  def test_get_adamw_mask_with_empty_mask(self):
    """An empty `adamw_mask` list produces no mask function (None)."""
    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=[]"]
    config = pyconfig.initialize(argv)
    mask_fn = optimizers.get_adamw_mask(config)
    self.assertIsNone(mask_fn)

  def test_get_adamw_mask_with_valid_mask(self):
    """A non-empty `adamw_mask` yields a callable that masks matching params."""
    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=['bias', '.*norm', '.*ln.*']"]
    config = pyconfig.initialize(argv)
    mask_fn = optimizers.get_adamw_mask(config)
    self.assertTrue(callable(mask_fn))

    params = {"layer1": {"kernel": 1, "bias": 2}, "layer2": {"layer_norm": {"scale": 3}}, "layer3": {"ln": {"scale": 4}}}
    mask = mask_fn(params)
    # Only the unmatched kernel keeps weight decay enabled.
    self.assertTrue(mask["layer1"]["kernel"])
    self.assertFalse(mask["layer1"]["bias"])  # matches 'bias'
    self.assertFalse(mask["layer2"]["layer_norm"]["scale"])  # matches '.*norm'
    self.assertFalse(mask["layer3"]["ln"]["scale"])  # matches '.*ln.*'

  def test_get_adamw_mask_with_invalid_mask(self):
    """An invalid regex in `adamw_mask` raises re.error when the mask is built.

    Note: the error surfaces from `get_adamw_mask` itself (while the patterns
    are processed), not from a later application of the mask to params.
    """
    # '[' is an unterminated character class and therefore an invalid regex.
    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=['[']"]
    config = pyconfig.initialize(argv)

    with self.assertRaises(re.error):
      optimizers.get_adamw_mask(config)

  def test_get_adamw_mask_with_getattrkey(self):
    """The mask function handles GetAttrKey path entries (e.g. NamedTuple fields)."""

    class MyParams(NamedTuple):
      kernel: jax.Array
      bias: jax.Array

    argv = ["", get_test_config_path(), "run_name=test", "adamw_mask=['bias']"]
    config = pyconfig.initialize(argv)
    mask_fn = optimizers.get_adamw_mask(config)

    params = MyParams(kernel=jax.numpy.ones((2, 2)), bias=jax.numpy.zeros((2,)))
    mask = mask_fn(params)

    self.assertTrue(mask.kernel)
    self.assertFalse(mask.bias)  # matches 'bias'

  @parameterized.named_parameters(
      ("adamw", "adamw", "maxtext.optimizers.optimizers.optax.adamw"),
      ("adam_pax", "adam_pax", "maxtext.optimizers.optimizers.adam_pax"),
  )
  def test_optimizer_with_mask(self, opt_type, mock_path):
    """get_optimizer forwards a working mask function to the underlying optimizer."""
    # Config with a mask list mixing plain names and path regexes.
    argv = [
        "",
        get_test_config_path(),
        "run_name=test",
        "adamw_mask=['bias', 'layer_norm', 'layer1/.*kernel']",
        f"opt_type={opt_type}",
    ]
    config = pyconfig.initialize(argv)
    learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)

    # Patch the underlying optimizer constructor to capture its kwargs.
    with patch(mock_path) as mock_opt:
      optimizers.get_optimizer(config, learning_rate_schedule)

    mock_opt.assert_called_once()
    _, kwargs = mock_opt.call_args
    mask_fn = kwargs["mask"]
    self.assertIsNotNone(mask_fn)

    # Exercise the captured mask on a small pytree, including a list leaf.
    params = {"layer1": {"kernel": 1, "bias": 2}, "layer2": {"layer_norm": {"scale": 3}}, "layer3": [4, 5]}
    mask = mask_fn(params)

    self.assertFalse(mask["layer1"]["kernel"])  # matches 'layer1/.*kernel'
    self.assertFalse(mask["layer1"]["bias"])  # matches 'bias'
    self.assertFalse(mask["layer2"]["layer_norm"]["scale"])  # matches 'layer_norm'
    # Unmatched list entries keep weight decay enabled.
    self.assertTrue(mask["layer3"][0])
    self.assertTrue(mask["layer3"][1])

  @parameterized.named_parameters(
      ("adamw", "adamw", "maxtext.optimizers.optimizers.optax.adamw"),
      ("adam_pax", "adam_pax", "maxtext.optimizers.optimizers.adam_pax"),
  )
  def test_optimizer_without_mask(self, opt_type, mock_path):
    """With no `adamw_mask` configured, the optimizer receives mask=None."""
    argv = ["", get_test_config_path(), "run_name=test", f"opt_type={opt_type}"]
    config = pyconfig.initialize(argv)
    learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)

    with patch(mock_path) as mock_opt:
      optimizers.get_optimizer(config, learning_rate_schedule)

    mock_opt.assert_called_once()
    _, kwargs = mock_opt.call_args
    self.assertIsNone(kwargs["mask"])
363+
364+
# Allow running this test module directly, e.g. `python muon_test.py`.
if __name__ == "__main__":
  unittest.main()
0 commit comments