Skip to content

Commit e4e72b3

Browse files
committed
Add support for config flag with no file path
1 parent 0fe1adf commit e4e72b3

2 files changed

Lines changed: 62 additions & 4 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import omegaconf
3030

3131
from maxtext.configs import pyconfig_deprecated
32-
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
32+
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
3333
from maxtext.common.common_types import DecoderBlockType, ShardMode
3434
from maxtext.configs import types
3535
from maxtext.configs.types import MaxTextConfig
@@ -46,6 +46,49 @@
4646
# Don't log the following keys.
4747
KEYS_NO_LOGGING = ("hf_access_token",)
4848

49+
# Module paths to their default config file (relative to MAXTEXT_CONFIGS_DIR).
50+
_CONFIG_FILE_MAPPING: dict[str, str] = {
51+
"maxtext.trainers.pre_train.train": "base.yml",
52+
"maxtext.trainers.pre_train.train_compile": "base.yml",
53+
"maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
54+
"maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
55+
"maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
56+
"maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
57+
"maxtext.inference.decode": "base.yml",
58+
"maxtext.inference.decode_multi": "base.yml",
59+
"maxtext.inference.inference_microbenchmark": "base.yml",
60+
"maxtext.inference.inference_microbenchmark_sweep": "base.yml",
61+
"maxtext.inference.maxengine.maxengine_server": "base.yml",
62+
"maxtext.inference.mlperf.microbenchmarks.benchmark_chunked_prefill": "base.yml",
63+
"maxtext.inference.vllm_decode": "base.yml",
64+
"maxtext.checkpoint_conversion.to_maxtext": "base.yml",
65+
"maxtext.checkpoint_conversion.to_huggingface": "base.yml",
66+
}
67+
68+
69+
def _module_from_path(path: str) -> str | None:
70+
"""Convert a file path to module path for config inference."""
71+
real_path = os.path.realpath(path)
72+
pkg_parent = os.path.realpath(os.path.dirname(MAXTEXT_PKG_DIR))
73+
if real_path.startswith(pkg_parent + os.sep):
74+
relative = os.path.relpath(real_path, pkg_parent)
75+
return relative.replace(os.sep, ".").removesuffix(".py")
76+
return None
77+
78+
79+
def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
80+
"""Resolves or infers config file path from module."""
81+
if len(argv) >= 2 and argv[1].endswith(".yml"):
82+
return resolve_config_path(argv[1]), argv[2:]
83+
module = _module_from_path(argv[0])
84+
if module not in _CONFIG_FILE_MAPPING:
85+
raise ValueError(
86+
f"No config file provided and no default config found for module '{module}'"
87+
)
88+
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
89+
logger.warning("No config file provided, using default config mapping: %s", config_path)
90+
return config_path, argv[1:]
91+
4992

5093
def yaml_key_to_env_key(s: str) -> str:
5194
return _MAX_PREFIX + s.upper()
@@ -227,11 +270,11 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
227270
Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
228271
"""
229272
# 1. Load base and inherited configs from file(s)
230-
config_path = resolve_config_path(argv[1])
273+
config_path, cli_args = _resolve_or_infer_config(argv)
231274
base_yml_config = _load_config(config_path)
232275

233276
# 2. Get overrides from CLI and kwargs
234-
cli_cfg = omegaconf.OmegaConf.from_cli(argv[2:])
277+
cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
235278
kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
236279
overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)
237280

tests/unit/pyconfig_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import unittest
2020

2121
from maxtext.configs import pyconfig
22-
from maxtext.configs.pyconfig import resolve_config_path
22+
from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path
2323
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
2424
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path
2525

@@ -115,6 +115,21 @@ def test_resolve_config_path_pip_install(self):
115115
finally:
116116
os.chdir(orig)
117117

118+
def test_config_file_mapping(self):
119+
for module, relative_path in _CONFIG_FILE_MAPPING.items():
120+
full_path = os.path.join(MAXTEXT_CONFIGS_DIR, relative_path)
121+
self.assertTrue(os.path.isfile(full_path), f"Default config for '{module}' not found at {full_path}")
122+
123+
def test_module_from_path(self):
124+
import maxtext.trainers.pre_train.train as train_module
125+
module_file = train_module.__file__
126+
result = _module_from_path(module_file)
127+
self.assertEqual(result, "maxtext.trainers.pre_train.train")
128+
129+
def test_unknown_module_raises(self):
130+
with self.assertRaises(ValueError):
131+
pyconfig.initialize_pydantic(["/custom_rl/module.py", "run_name=test"])
132+
118133

119134
if __name__ == "__main__":
120135
unittest.main()

0 commit comments

Comments
 (0)