Skip to content

Commit cd5cd5e

Browse files
Merge pull request #2945 from ROCm:gw_slurm_multigpu
PiperOrigin-RevId: 875912810
2 parents 5c76bdd + 00e5d30 commit cd5cd5e

2 files changed

Lines changed: 82 additions & 2 deletions

File tree

src/maxtext/utils/max_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Common Max Utils needed by multiple modules.
15+
"""Common Max Utils needed by multiple modules.
1616
All the functions include MaxText modules, such as Pyconfig, should be moved to MaxText utils file."""
1717

1818
import collections
@@ -248,11 +248,20 @@ def initialize_jax_for_gpu(raw_keys):
248248
if os.environ.get("JAX_COORDINATOR_IP") is not None:
249249
coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
250250
coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
251+
devices = os.getenv("CUDA_VISIBLE_DEVICES")
252+
if devices is not None:
253+
try:
254+
devices = [int(x) for x in devices.split(",")]
255+
except (ValueError, TypeError) as e:
256+
max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
257+
devices = None
258+
251259
jax.distributed.initialize(
252260
coordinator_address=f"{coordinator_ip}:{coordinator_port}",
253261
num_processes=int(os.getenv("NNODES")),
254262
process_id=int(os.getenv("NODE_RANK")),
255263
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
264+
local_device_ids=devices,
256265
)
257266
max_logging.log(f"JAX global devices: {jax.devices()}")
258267

tests/unit/max_utils_test.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Tests for the common Max Utils """
15+
"""Tests for the common Max Utils"""
16+
import os
1617
import sys
1718
import unittest
1819
import time
1920
import pytest
2021

22+
2123
import jax
2224
from jax import numpy as jnp
2325
from jax import random
@@ -30,6 +32,7 @@
3032
from maxtext.utils import max_utils
3133
from maxtext.utils.train_utils import setup_train_loop
3234
from tests.utils.test_helpers import get_test_config_path
35+
from unittest import mock
3336

3437

3538
class MaxUtilsSummaryStats(unittest.TestCase):
@@ -168,5 +171,73 @@ def test_unscan_train_state_params(self):
168171
self.assertNotIn("layers_0", state.params["params"]["decoder"])
169172

170173

174+
class TestGpuDistributedInitialization(unittest.TestCase):
175+
"""Tests using CUDA_VISIBLE_DEVICES to control which GPUs are used in jax.distributed.initialize."""
176+
177+
@mock.patch.dict(
178+
os.environ,
179+
{
180+
"JAX_COORDINATOR_IP": "10.0.0.1",
181+
"JAX_COORDINATOR_PORT": "1234",
182+
"NNODES": "1",
183+
"NODE_RANK": "0",
184+
"CUDA_VISIBLE_DEVICES": "0,2,3", # Simulating Slurm/orchestrator assignment
185+
},
186+
)
187+
@mock.patch("jax.distributed.initialize")
188+
@mock.patch("jax.devices")
189+
@mock.patch("maxtext.utils.max_logging.log")
190+
def test_initialize_jax_for_gpu_valid_devices(self, _mock_log, _mock_devices, mock_init):
191+
"""Verifies that a comma-separated string of IDs is correctly parsed."""
192+
raw_keys = {"jax_distributed_initialization_timeout": 300}
193+
max_utils.initialize_jax_for_gpu(raw_keys)
194+
# Check that local_device_ids was passed correctly as a list of integers
195+
_, kwargs = mock_init.call_args
196+
self.assertEqual(kwargs["local_device_ids"], [0, 2, 3])
197+
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
198+
199+
@mock.patch.dict(
200+
os.environ,
201+
{
202+
"JAX_COORDINATOR_IP": "10.0.0.1",
203+
"JAX_COORDINATOR_PORT": "1234",
204+
"NNODES": "1",
205+
"NODE_RANK": "0",
206+
"CUDA_VISIBLE_DEVICES": "GPU-8f2e3072-...", # Invalid format for integer parsing
207+
},
208+
)
209+
@mock.patch("jax.distributed.initialize")
210+
@mock.patch("jax.devices")
211+
@mock.patch("maxtext.utils.max_logging.log")
212+
def test_initialize_jax_for_gpu_invalid_devices(self, _mock_log, mock_devices, mock_init):
213+
"""Verifies fallback behavior when parsing fails (e.g., UUIDs)."""
214+
raw_keys = {"jax_distributed_initialization_timeout": 300}
215+
max_utils.initialize_jax_for_gpu(raw_keys)
216+
# Check that it falls back to None (JAX auto-detection default) on error
217+
_, kwargs = mock_init.call_args
218+
self.assertIsNone(kwargs.get("local_device_ids"))
219+
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
220+
221+
@mock.patch.dict(
222+
os.environ,
223+
{
224+
"JAX_COORDINATOR_IP": "10.0.0.1",
225+
"JAX_COORDINATOR_PORT": "1234",
226+
"NNODES": "1",
227+
"NODE_RANK": "0",
228+
},
229+
)
230+
@mock.patch("jax.distributed.initialize")
231+
@mock.patch("jax.devices")
232+
@mock.patch("maxtext.utils.max_logging.log")
233+
def test_initialize_jax_for_gpu_no_devices(self, _mock_log, mock_devices, mock_init):
234+
"""Verifies that no error occurs when CUDA_VISIBLE_DEVICES is not set"""
235+
raw_keys = {"jax_distributed_initialization_timeout": 300}
236+
max_utils.initialize_jax_for_gpu(raw_keys)
237+
_, kwargs = mock_init.call_args
238+
self.assertIsNone(kwargs.get("local_device_ids"))
239+
self.assertEqual(kwargs["coordinator_address"], "10.0.0.1:1234")
240+
241+
171242
if __name__ == "__main__":
172243
unittest.main()

0 commit comments

Comments
 (0)