|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -""" Tests for the common Max Utils """ |
| 15 | +"""Tests for the common Max Utils""" |
| 16 | +import os |
16 | 17 | import sys |
17 | 18 | import unittest |
18 | 19 | import time |
19 | 20 | import pytest |
20 | 21 |
|
| 22 | + |
21 | 23 | import jax |
22 | 24 | from jax import numpy as jnp |
23 | 25 | from jax import random |
|
30 | 32 | from maxtext.utils import max_utils |
31 | 33 | from maxtext.utils.train_utils import setup_train_loop |
32 | 34 | from tests.utils.test_helpers import get_test_config_path |
| 35 | +from unittest import mock |
33 | 36 |
|
34 | 37 |
|
35 | 38 | class MaxUtilsSummaryStats(unittest.TestCase): |
@@ -168,5 +171,73 @@ def test_unscan_train_state_params(self): |
168 | 171 | self.assertNotIn("layers_0", state.params["params"]["decoder"]) |
169 | 172 |
|
170 | 173 |
|
class TestGpuDistributedInitialization(unittest.TestCase):
  """Tests how initialize_jax_for_gpu turns CUDA_VISIBLE_DEVICES into local_device_ids.

  jax.distributed.initialize is mocked in every test, so no coordinator is
  contacted; the tests only inspect the keyword arguments it receives.
  jax.devices and max_logging.log are mocked as well (presumably because the
  function under test logs the device list after initializing -- confirm
  against max_utils).
  """

  # Coordinator settings shared by all tests. CUDA_VISIBLE_DEVICES is
  # deliberately left out so that each test controls it explicitly.
  _BASE_ENV = {
      "JAX_COORDINATOR_IP": "10.0.0.1",
      "JAX_COORDINATOR_PORT": "1234",
      "NNODES": "1",
      "NODE_RANK": "0",
  }

  def _initialize_and_get_kwargs(self, mock_init):
    """Runs initialize_jax_for_gpu and returns the kwargs given to the mocked initialize."""
    # A fresh dict per call keeps tests independent even if the callee mutates it.
    raw_keys = {"jax_distributed_initialization_timeout": 300}
    max_utils.initialize_jax_for_gpu(raw_keys)
    # Fail with a clear message instead of a TypeError from unpacking call_args=None.
    mock_init.assert_called_once()
    _, kwargs = mock_init.call_args
    return kwargs

  @mock.patch.dict(
      os.environ,
      {**_BASE_ENV, "CUDA_VISIBLE_DEVICES": "0,2,3"},  # Simulating Slurm/orchestrator assignment
  )
  @mock.patch("jax.distributed.initialize")
  @mock.patch("jax.devices")
  @mock.patch("maxtext.utils.max_logging.log")
  def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mock_init):
    """Verifies that a comma-separated string of IDs is correctly parsed."""
    kwargs = self._initialize_and_get_kwargs(mock_init)
    # local_device_ids must arrive as a list of integers, not the raw string.
    self.assertEqual(kwargs["local_device_ids"], [0, 2, 3])
    self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")

  @mock.patch.dict(
      os.environ,
      {**_BASE_ENV, "CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-..."},  # Invalid format for integer parsing
  )
  @mock.patch("jax.distributed.initialize")
  @mock.patch("jax.devices")
  @mock.patch("maxtext.utils.max_logging.log")
  def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, _mock_devices, mock_init):
    """Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
    kwargs = self._initialize_and_get_kwargs(mock_init)
    # On a parse error the value falls back to None (JAX auto-detection default).
    self.assertIsNone(kwargs.get("local_device_ids"))
    self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")

  @mock.patch.dict(os.environ, _BASE_ENV)
  @mock.patch("jax.distributed.initialize")
  @mock.patch("jax.devices")
  @mock.patch("maxtext.utils.max_logging.log")
  def test_initialize_jax_for_gpu_no_devices(self, _mock_log, _mock_devices, mock_init):
    """Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set."""
    # patch.dict only adds keys, so a CUDA_VISIBLE_DEVICES exported by the host
    # (GPU CI runner, Slurm job) would otherwise leak into this test. Remove it
    # here; patch.dict restores the original environment when the test exits.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    kwargs = self._initialize_and_get_kwargs(mock_init)
    self.assertIsNone(kwargs.get("local_device_ids"))
    self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
| 241 | + |
171 | 242 | if __name__ == "__main__": |
172 | 243 | unittest.main() |
0 commit comments