Skip to content

Commit 6550934

Browse files
Merge pull request #3272 from ROCm:ut-decoupled
PiperOrigin-RevId: 879653255
2 parents 77edafe + a0fc173 commit 6550934

22 files changed

Lines changed: 151 additions & 152 deletions

docs/run_maxtext/decoupled_mode.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ When enabled:
2828
- Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
2929
- Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
3030
- Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
31-
- All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_utils.py`. This helper ensures usage of `decoupled_base_test.yml`.
31+
- Many tests use the helper `get_test_config_path()` from `tests/utils/test_helpers.py`. In decoupled mode, this helper selects `src/maxtext/configs/decoupled_base_test.yml` instead of `src/maxtext/configs/base.yml`.
3232

3333
Minimal datasets included (checked into the repo):
3434

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
# This sentinel is a reminder to choose a real run name.
1616
# If there is already a checkpoint under this run, that checkpoint will auto-resume.
17+
#
18+
# NOTE: Some unit/integration tests in MaxText do not always run this file directly.
19+
# When running in decoupled mode (DECOUPLE_GCLOUD=TRUE), tests may use
20+
# `decoupled_base_test.yml` instead of `base.yml` via `tests/utils/test_helpers.py`.
1721
run_name: ""
1822

1923
model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!

tests/integration/checkpointing_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@
3030
from math import isclose
3131
import os.path
3232

33-
import jax
3433
import pytest
3534

3635
from maxtext.common.gcloud_stub import is_decoupled
3736
from maxtext.trainers.pre_train.train import main as train_main
3837
from maxtext.utils.globals import MAXTEXT_PKG_DIR
39-
from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory
38+
from tests.utils.test_helpers import (
39+
get_test_config_path,
40+
get_test_base_output_directory,
41+
get_decoupled_parallelism_overrides,
42+
)
4043

4144

4245
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path):
@@ -72,10 +75,7 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
7275

7376
extra_parallelism = []
7477
if is_decoupled(): # Match device topology in decoupled/local mode
75-
try:
76-
extra_parallelism.append(f"ici_fsdp_parallelism={jax.device_count()}")
77-
except Exception as e: # pragma: no cover - defensive # pylint: disable=broad-exception-caught
78-
print(f"Warning: unable to determine jax.device_count(): {e}")
78+
extra_parallelism.extend(get_decoupled_parallelism_overrides(as_argv=True))
7979

8080
return (
8181
[

tests/integration/smoke/train_gpu_smoke_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
from maxtext.common.gcloud_stub import is_decoupled
2222
from maxtext.trainers.pre_train.train import main as train_main
23-
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR
24-
from tests.utils.test_helpers import get_test_dataset_path, get_test_base_output_directory
23+
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
24+
from tests.utils.test_helpers import get_test_dataset_path, get_test_base_output_directory, get_test_config_path
2525

2626

2727
class Train(unittest.TestCase):
@@ -43,7 +43,7 @@ def test_tiny_config(self):
4343
train_main(
4444
[
4545
None,
46-
os.path.join(MAXTEXT_PKG_DIR, "configs", "gpu", "gpu_smoke_test.yml"),
46+
get_test_config_path("gpu/gpu_smoke_test.yml"),
4747
# pylint: disable=f-string-without-interpolation
4848
f"base_output_directory={self.base_output_directory}",
4949
"run_name=runner_test",

tests/integration/train_tests.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
from maxtext.common.gcloud_stub import is_decoupled
2323
from maxtext.trainers.pre_train.train import main as train_main
2424
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
25-
from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory
25+
from tests.utils.test_helpers import (
26+
get_test_config_path,
27+
get_test_dataset_path,
28+
get_test_base_output_directory,
29+
get_decoupled_parallelism_overrides,
30+
is_rocm_backend,
31+
)
2632

2733

2834
class TrainTests(unittest.TestCase):
@@ -37,9 +43,9 @@ class TrainTests(unittest.TestCase):
3743
_fsdp_tp4_override = []
3844
if decoupled:
3945
if dev_count >= 4 and dev_count % 4 == 0:
40-
_fsdp_tp4_override = [f"ici_fsdp_parallelism={dev_count // 4}"]
46+
_fsdp_tp4_override = get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count // 4, as_argv=True)
4147
elif dev_count < 4:
42-
_fsdp_tp4_override = [f"ici_fsdp_parallelism={dev_count}"]
48+
_fsdp_tp4_override = get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True)
4349

4450
CONFIGS = {
4551
"base": [ # short test for train.py with TFDS c4
@@ -53,7 +59,7 @@ class TrainTests(unittest.TestCase):
5359
"enable_goodput_recording=False",
5460
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
5561
]
56-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
62+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
5763
"synthetic": [ # tests base config with synthetic dataset
5864
None,
5965
get_test_config_path(),
@@ -66,7 +72,7 @@ class TrainTests(unittest.TestCase):
6672
"dataset_type=synthetic",
6773
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
6874
]
69-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
75+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
7076
"pdb_lt_1": [ # tests base config with per_device_batch_size < 1
7177
None,
7278
get_test_config_path(),
@@ -80,7 +86,7 @@ class TrainTests(unittest.TestCase):
8086
"ici_tensor_parallelism=4",
8187
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
8288
]
83-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
89+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
8490
"tp_transpose": [ # tests base config with ici_tensor_transpose_parallelism=4
8591
None,
8692
get_test_config_path(),
@@ -92,7 +98,7 @@ class TrainTests(unittest.TestCase):
9298
"enable_goodput_recording=False",
9399
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
94100
]
95-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
101+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
96102
"int8": [ # tests base config with int8
97103
None,
98104
get_test_config_path(),
@@ -105,7 +111,7 @@ class TrainTests(unittest.TestCase):
105111
"enable_goodput_recording=False",
106112
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
107113
]
108-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
114+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
109115
"fp8": [ # tests base config with fp8
110116
None,
111117
get_test_config_path(),
@@ -118,7 +124,7 @@ class TrainTests(unittest.TestCase):
118124
"enable_goodput_recording=False",
119125
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
120126
]
121-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
127+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
122128
"nanoo_fp8": [ # tests base config with nanoo_fp8
123129
None,
124130
get_test_config_path(),
@@ -131,7 +137,7 @@ class TrainTests(unittest.TestCase):
131137
"enable_goodput_recording=False",
132138
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
133139
]
134-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
140+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
135141
"te_fp8_delayedscaling": [ # tests base config with te_fp8_delayedscaling
136142
None,
137143
get_test_config_path(),
@@ -144,7 +150,7 @@ class TrainTests(unittest.TestCase):
144150
"enable_goodput_recording=False",
145151
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
146152
]
147-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
153+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
148154
"te_fp8_currentscaling": [ # tests base config with te_fp8_currentscaling
149155
None,
150156
get_test_config_path(),
@@ -157,7 +163,7 @@ class TrainTests(unittest.TestCase):
157163
"enable_goodput_recording=False",
158164
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
159165
]
160-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
166+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
161167
"te_mxfp8": [ # tests base config with te_mxfp8
162168
None,
163169
get_test_config_path(),
@@ -170,7 +176,7 @@ class TrainTests(unittest.TestCase):
170176
"enable_goodput_recording=False",
171177
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
172178
]
173-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
179+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
174180
"dropout": [ # tests base config with dropout
175181
None,
176182
get_test_config_path(),
@@ -185,7 +191,7 @@ class TrainTests(unittest.TestCase):
185191
"dropout_rate=0.02",
186192
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
187193
]
188-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
194+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
189195
"hf_input_pipeline": [ # test for train.py with TFDS c4, using HF input pipeline
190196
None,
191197
get_test_config_path(),
@@ -199,7 +205,7 @@ class TrainTests(unittest.TestCase):
199205
f"hf_train_files={dataset_path}/hf/c4/c4-train-00000-of-01637.parquet",
200206
"tokenizer_path=google-t5/t5-large",
201207
]
202-
+ ([f"ici_fsdp_parallelism={dev_count}"] if decoupled else []),
208+
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
203209
}
204210

205211
@pytest.mark.integration_test
@@ -427,7 +433,7 @@ def test_gpu_optimizer_offload(self):
427433
"enable_goodput_recording=False",
428434
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
429435
]
430-
train_main(optimizer_offload + ([f"ici_fsdp_parallelism={self.dev_count}"] if self.decoupled else []))
436+
train_main(optimizer_offload + get_decoupled_parallelism_overrides(fsdp_parallelism=self.dev_count, as_argv=True))
431437

432438
@pytest.mark.integration_test
433439
@pytest.mark.gpu_only
@@ -448,7 +454,7 @@ def test_gpu_parameter_offload(self):
448454
"enable_goodput_recording=False",
449455
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
450456
]
451-
train_main(parameter_offload + ([f"ici_fsdp_parallelism={self.dev_count}"] if self.decoupled else []))
457+
train_main(parameter_offload + get_decoupled_parallelism_overrides(fsdp_parallelism=self.dev_count, as_argv=True))
452458

453459
@pytest.mark.gpu_only
454460
def test_gpu_cudnn_flash_jax(self):
@@ -567,6 +573,8 @@ def test_gpu_packed_attention(self):
567573
@pytest.mark.gpu_only
568574
@pytest.mark.skip(reason="b/489133823. Previously transient in b/462548581.")
569575
def test_gpu_ring_attention(self):
576+
if is_rocm_backend():
577+
pytest.skip("TE ring attention context parallelism not supported on ROCm.")
570578
os.environ["NVTE_FUSED_ATTN"] = "1" # Enable fused attention
571579
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" # Disable scan for ring attention
572580
ring_attention = [ # tests base config on GPU with ring attention

tests/unit/attention_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
import pytest
4545

4646
from tests.utils import attention_test_util
47-
from tests.utils.test_helpers import get_test_config_path
47+
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides
4848

4949

5050
class BidirectionalBlockMaskTest(unittest.TestCase):
@@ -290,7 +290,7 @@ def setUp(self):
290290
"""Initializes the configuration for each test"""
291291
super().setUp()
292292
# Conditionally set ici_fsdp_parallelism to match device count in decoupled mode
293-
extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
293+
extra_args = get_decoupled_parallelism_overrides()
294294
if not is_decoupled():
295295
jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
296296
config = pyconfig.initialize(
@@ -1335,7 +1335,7 @@ def test_projection_initialization(self):
13351335
# Create a copy of the arguments and override the attention_type for the base model
13361336
attention_config_args = self.config_arguments.copy()
13371337
attention_config_args["attention_type"] = AttentionType.GLOBAL.value
1338-
extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
1338+
extra_args = get_decoupled_parallelism_overrides()
13391339
attention_cfg = pyconfig.initialize(
13401340
[sys.argv[0], get_test_config_path()],
13411341
**attention_config_args,
@@ -1371,10 +1371,9 @@ def test_projection_initialization(self):
13711371

13721372
# 3. Initialize the MLA layer
13731373
mla_config_args = self.config_arguments.copy()
1374-
mla_extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
1374+
mla_extra_args = get_decoupled_parallelism_overrides()
13751375
mla_config_args.update(mla_extra_args)
13761376
_, mla_layer = self.init_mla(mla_config_args, rope_type="default")
1377-
_, mla_layer = self.init_mla(self.config_arguments, rope_type="default")
13781377

13791378
# 4. Assert that the MLA layer DOES NOT HAVE the base projections
13801379
self.assertFalse(hasattr(mla_layer, "query"), "MLA should not have 'query' projection.")

tests/unit/data_loader_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from maxtext.utils.maxtext_utils import create_device_mesh
3131
from maxtext.common.gcloud_stub import is_decoupled
3232
from maxtext.utils.rampup_batch import RampupBatchManager
33-
from tests.utils.test_helpers import get_test_config_path
33+
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides
3434

3535

3636
class DataLoaderTest(unittest.TestCase):
@@ -63,9 +63,7 @@ def get_test_config(self, reuse_example_batch, **kwargs):
6363
# In decoupled mode, adapt mesh/ICI parallelism so that the
6464
# product of ICI parallelism matches the available devices for
6565
# this test only.
66-
if is_decoupled():
67-
args.setdefault("mesh_axes", ["data"])
68-
args.setdefault("ici_data_parallelism", -1)
66+
args.update(get_decoupled_parallelism_overrides(include_mesh_defaults=True))
6967

7068
return pyconfig.initialize(
7169
[None, get_test_config_path()],

tests/unit/engram_vs_reference_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from typing import List
2929
from dataclasses import dataclass, field
3030
import math
31-
import os
3231
import unittest
3332
from absl.testing import parameterized
3433

@@ -52,7 +51,7 @@
5251
from maxtext.layers.engram import ShortConv as ShortConvJAX
5352
from maxtext.layers.engram import Engram as EngramJAX
5453
from maxtext.utils import maxtext_utils
55-
from maxtext.utils.globals import MAXTEXT_PKG_DIR
54+
from tests.utils.test_helpers import get_test_config_path
5655

5756

5857
def setUpModule():
@@ -470,7 +469,7 @@ def init_torch_weights(module, std=1):
470469
def get_cfg_and_mesh(config):
471470
"""Returns MaxText configuration and mesh."""
472471
cfg = pyconfig.initialize(
473-
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
472+
[None, get_test_config_path()],
474473
run_name="",
475474
enable_checkpointing=False,
476475
model_name="default",

tests/unit/maxtext_utils_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import jax.numpy as jnp
2929
from jax.sharding import Mesh, NamedSharding, PartitionSpec
3030
from maxtext.configs import pyconfig
31-
from maxtext.common.gcloud_stub import is_decoupled
3231
from maxtext.common.common_types import MODEL_MODE_TRAIN
3332
from maxtext.inference import inference_utils
3433
from maxtext.layers import quantizations
@@ -37,7 +36,7 @@
3736
from maxtext.utils import maxtext_utils
3837
from maxtext.utils import sharding
3938
from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations
40-
from tests.utils.test_helpers import get_test_config_path
39+
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides
4140
import numpy as np
4241
import optax
4342

@@ -347,7 +346,7 @@ class MaxUtilsInitTransformerState(unittest.TestCase):
347346

348347
def setUp(self):
349348
# Conditionally set ici_fsdp_parallelism to match device count in decoupled mode
350-
extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {}
349+
extra_args = get_decoupled_parallelism_overrides()
351350
self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args)
352351
devices_array = maxtext_utils.create_device_mesh(self.config)
353352
self.mesh = Mesh(devices_array, self.config.mesh_axes)
@@ -913,8 +912,10 @@ class TestGetAbstractState(unittest.TestCase):
913912
"""Test class for get_abstract_state."""
914913

915914
def setUp(self):
915+
extra_args = get_decoupled_parallelism_overrides()
916916
self.config = pyconfig.initialize(
917917
[None, get_test_config_path()],
918+
**extra_args,
918919
enable_checkpointing=False,
919920
model_name="llama3.1-8b",
920921
per_device_batch_size=1,

tests/unit/mhc_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616

17-
import os.path
1817
import unittest
1918
import pytest
2019

@@ -26,12 +25,12 @@
2625
import numpy as np
2726

2827
from maxtext.configs import pyconfig
29-
from maxtext.utils.globals import MAXTEXT_PKG_DIR
3028
from maxtext.common.common_types import HyperConnectionType
3129
from maxtext.layers import attention_mla, linears, mhc, moe
3230
from maxtext.layers.initializers import nd_dense_init
3331
from maxtext.layers.normalizations import RMSNorm
3432
from maxtext.utils import maxtext_utils
33+
from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides
3534

3635

3736
class TestExpandReduce(unittest.TestCase):
@@ -92,8 +91,10 @@ class TestMHC(unittest.TestCase):
9291

9392
def setUp(self):
9493
self.dim = 16
94+
extra_args = get_decoupled_parallelism_overrides()
9595
self.config = pyconfig.initialize(
96-
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
96+
[None, get_test_config_path()],
97+
**extra_args,
9798
run_name="test_mhc",
9899
enable_checkpointing=False,
99100
model_name="deepseek-custom",

0 commit comments

Comments
 (0)