Skip to content

Commit 81bbd64

Browse files
Merge pull request #3360 from AI-Hypercomputer:xfgu-fixit
PiperOrigin-RevId: 881496802
2 parents fd03282 + 44ccec9 commit 81bbd64

1 file changed

Lines changed: 98 additions & 0 deletions

File tree

tests/unit/train_rl_test.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for train_rl.py."""
16+
17+
import unittest
18+
from unittest import mock
19+
import pytest
20+
from types import SimpleNamespace
21+
import jax
22+
23+
24+
# Same as in rl_utils_test.py.
25+
train_rl = pytest.importorskip(
26+
"maxtext.trainers.post_train.rl.train_rl",
27+
reason="Tunix is not installed on the GPU image",
28+
)
29+
30+
31+
def _get_mock_devices(num_devices):
32+
mock_devices = [mock.MagicMock() for _ in range(num_devices)]
33+
for i, d in enumerate(mock_devices):
34+
d.id = i
35+
return mock_devices
36+
37+
38+
class TrainRLTest(unittest.TestCase):
39+
"""Tests for train_rl.py."""
40+
41+
@pytest.mark.cpu_only
42+
def test_setup_configs_and_devices_pathways_split(self):
43+
"""Test setup_configs_and_devices with multiple VMs and Pathways."""
44+
mock_devices = _get_mock_devices(8)
45+
46+
mock_config = SimpleNamespace(
47+
num_trainer_slices=-1,
48+
num_samplers_slices=-1,
49+
chips_per_vm=4,
50+
use_pathways=True,
51+
trainer_devices_fraction=0.5,
52+
sampler_devices_fraction=0.5,
53+
)
54+
55+
# Following the pattern in distillation_checkpointing_test.py for mocking jax objects
56+
with (
57+
mock.patch.object(jax, "devices", return_value=mock_devices),
58+
mock.patch("maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", return_value=mock_config),
59+
):
60+
trainer_config, sampler_config, trainer_devices, sampler_devices = train_rl.setup_configs_and_devices(
61+
["dummy", "dummy"]
62+
)
63+
64+
self.assertEqual(trainer_config, mock_config)
65+
self.assertEqual(sampler_config, mock_config)
66+
self.assertEqual(len(trainer_devices), 4)
67+
self.assertEqual(len(sampler_devices), 4)
68+
self.assertEqual(trainer_devices, mock_devices[:4])
69+
self.assertEqual(sampler_devices, mock_devices[4:])
70+
71+
@pytest.mark.cpu_only
72+
def test_setup_configs_and_devices_pathways_fractional_split(self):
73+
"""Test setup_configs_and_devices with multiple VMs and custom fractions."""
74+
mock_devices = _get_mock_devices(8)
75+
76+
mock_config = SimpleNamespace(
77+
num_trainer_slices=-1,
78+
num_samplers_slices=-1,
79+
chips_per_vm=4,
80+
use_pathways=True,
81+
trainer_devices_fraction=0.25,
82+
sampler_devices_fraction=0.75,
83+
)
84+
85+
with (
86+
mock.patch.object(jax, "devices", return_value=mock_devices),
87+
mock.patch("maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", return_value=mock_config),
88+
):
89+
_, _, trainer_devices, sampler_devices = train_rl.setup_configs_and_devices(["dummy", "dummy"])
90+
91+
self.assertEqual(len(trainer_devices), 2)
92+
self.assertEqual(len(sampler_devices), 6)
93+
self.assertEqual(trainer_devices, mock_devices[:2])
94+
self.assertEqual(sampler_devices, mock_devices[2:])
95+
96+
97+
if __name__ == "__main__":
98+
unittest.main()

0 commit comments

Comments
 (0)