Skip to content

Commit 0fe1adf

Browse files
Merge pull request #3387 from AI-Hypercomputer:xfgu-fixit
PiperOrigin-RevId: 882293460
2 parents 63f5ca9 + 7341b7c commit 0fe1adf

1 file changed

Lines changed: 87 additions & 4 deletions

File tree

tests/unit/train_rl_test.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@
2828
)
2929

3030

31-
def _get_mock_devices(num_devices):
32-
mock_devices = [mock.MagicMock() for _ in range(num_devices)]
33-
for i, d in enumerate(mock_devices):
34-
d.id = i
31+
def _get_mock_devices(devices_per_slice, num_slices=1):
32+
mock_devices = []
33+
for slice_idx in range(num_slices):
34+
for _ in range(devices_per_slice):
35+
d = mock.MagicMock()
36+
d.id = len(mock_devices)
37+
d.slice_index = slice_idx
38+
mock_devices.append(d)
3539
return mock_devices
3640

3741

@@ -93,6 +97,85 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
9397
self.assertEqual(trainer_devices, mock_devices[:2])
9498
self.assertEqual(sampler_devices, mock_devices[2:])
9599

100+
@pytest.mark.cpu_only
101+
def test_setup_configs_and_devices_multislice_not_enough_slices(self):
102+
"""Test setup_configs_and_devices raises ValueError when not enough slices."""
103+
mock_devices = _get_mock_devices(num_slices=2, devices_per_slice=4)
104+
mock_config = SimpleNamespace(
105+
num_trainer_slices=2,
106+
num_samplers_slices=1,
107+
)
108+
109+
def side_effect(argv, **kwargs):
110+
res = SimpleNamespace(**vars(mock_config))
111+
for k, v in kwargs.items():
112+
setattr(res, k, v)
113+
return res
114+
115+
with (
116+
mock.patch.object(jax, "devices", return_value=mock_devices),
117+
mock.patch(
118+
"maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic",
119+
side_effect=side_effect,
120+
),
121+
):
122+
with self.assertRaisesRegex(ValueError, "Not enough slices for trainer and samplers"):
123+
train_rl.setup_configs_and_devices(["dummy", "dummy"])
124+
125+
@pytest.mark.cpu_only
126+
def test_setup_configs_and_devices_multislice_invalid_tp(self):
127+
"""Test setup_configs_and_devices raises ValueError for invalid TP."""
128+
mock_devices = _get_mock_devices(num_slices=4, devices_per_slice=8)
129+
mock_config = SimpleNamespace(
130+
num_trainer_slices=2,
131+
num_samplers_slices=2,
132+
ici_tensor_parallelism=3, # 8 is not divisible by 3
133+
ici_fsdp_parallelism=-1,
134+
)
135+
136+
def side_effect(argv, **kwargs):
137+
res = SimpleNamespace(**vars(mock_config))
138+
for k, v in kwargs.items():
139+
setattr(res, k, v)
140+
return res
141+
142+
with (
143+
mock.patch.object(jax, "devices", return_value=mock_devices),
144+
mock.patch(
145+
"maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic",
146+
side_effect=side_effect,
147+
),
148+
):
149+
with self.assertRaisesRegex(ValueError, "must be divisible by tensor parallelism"):
150+
train_rl.setup_configs_and_devices(["dummy", "dummy"])
151+
152+
@pytest.mark.cpu_only
153+
def test_setup_configs_and_devices_multislice_invalid_tp_fsdp(self):
154+
"""Test setup_configs_and_devices raises ValueError for inconsistent TP and FSDP."""
155+
mock_devices = _get_mock_devices(num_slices=4, devices_per_slice=8)
156+
mock_config = SimpleNamespace(
157+
num_trainer_slices=2,
158+
num_samplers_slices=2,
159+
ici_tensor_parallelism=4,
160+
ici_fsdp_parallelism=3, # 4 * 3 != 8
161+
)
162+
163+
def side_effect(argv, **kwargs):
164+
res = SimpleNamespace(**vars(mock_config))
165+
for k, v in kwargs.items():
166+
setattr(res, k, v)
167+
return res
168+
169+
with (
170+
mock.patch.object(jax, "devices", return_value=mock_devices),
171+
mock.patch(
172+
"maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic",
173+
side_effect=side_effect,
174+
),
175+
):
176+
with self.assertRaisesRegex(ValueError, "must equal devices_per_slice"):
177+
train_rl.setup_configs_and_devices(["dummy", "dummy"])
178+
96179
@pytest.mark.cpu_only
97180
def test_get_rollout_kwargs_no_dp(self):
98181
"""Test case 1: sampler_config.rollout_data_parallelism=-1 -> verify result is calculated."""

0 commit comments

Comments
 (0)