|
29 | 29 | import omegaconf |
30 | 30 |
|
31 | 31 | from maxtext.configs import pyconfig_deprecated |
32 | | -from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR |
| 32 | +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR |
33 | 33 | from maxtext.common.common_types import DecoderBlockType, ShardMode |
34 | 34 | from maxtext.configs import types |
35 | 35 | from maxtext.configs.types import MaxTextConfig |
|
46 | 46 | # Don't log the following keys. |
47 | 47 | KEYS_NO_LOGGING = ("hf_access_token",) |
48 | 48 |
|
# Default config file for each entry-point module, keyed by the module's dotted
# path. Values are paths relative to MAXTEXT_CONFIGS_DIR. This table is consulted
# by _resolve_or_infer_config when no ".yml" config is passed as the first
# command-line argument.
_CONFIG_FILE_MAPPING: dict[str, str] = {
    "maxtext.trainers.pre_train.train": "base.yml",
    "maxtext.trainers.pre_train.train_compile": "base.yml",
    "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
    "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
    "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
    "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
    "maxtext.inference.decode": "base.yml",
    "maxtext.inference.decode_multi": "base.yml",
    "maxtext.inference.inference_microbenchmark": "base.yml",
    "maxtext.inference.inference_microbenchmark_sweep": "base.yml",
    "maxtext.inference.maxengine.maxengine_server": "base.yml",
    "maxtext.inference.mlperf.microbenchmarks.benchmark_chunked_prefill": "base.yml",
    "maxtext.inference.vllm_decode": "base.yml",
    "maxtext.checkpoint_conversion.to_maxtext": "base.yml",
    "maxtext.checkpoint_conversion.to_huggingface": "base.yml",
}
| 67 | + |
| 68 | + |
def _module_from_path(path: str) -> str | None:
  """Translate a script's file path into a dotted module path.

  Both the given path and the parent directory of MAXTEXT_PKG_DIR are passed
  through os.path.realpath before they are compared.

  Args:
    path: File path of the script (for example, argv[0]).

  Returns:
    The path relative to the parent directory of MAXTEXT_PKG_DIR, with path
    separators replaced by dots and a trailing ".py" removed, or None when the
    path does not lie under that directory.
  """
  package_root = os.path.realpath(os.path.dirname(MAXTEXT_PKG_DIR))
  script_path = os.path.realpath(path)
  # Guard clause: anything outside the package root has no module path.
  if not script_path.startswith(package_root + os.sep):
    return None
  dotted = os.path.relpath(script_path, package_root).replace(os.sep, ".")
  return dotted.removesuffix(".py")
| 77 | + |
| 78 | + |
def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
  """Resolve the config file named on the command line, or infer a default.

  If argv[1] ends with ".yml" it is treated as the config file and resolved
  with resolve_config_path. Otherwise the default config is looked up in
  _CONFIG_FILE_MAPPING using the module path derived from argv[0], and a
  warning is logged naming the config that was chosen.

  Args:
    argv: Command-line arguments; argv[0] is the path of the running script.

  Returns:
    A (config_path, cli_args) tuple. cli_args is argv[2:] when an explicit
    config was given and argv[1:] when the config was inferred.

  Raises:
    ValueError: If argv is empty, if argv[0] does not map to a module path, or
      if the module has no entry in _CONFIG_FILE_MAPPING.
  """
  # Without argv[0] there is nothing to infer from; fail with the same
  # exception type as the other "cannot determine config" cases rather than
  # letting an IndexError escape.
  if not argv:
    raise ValueError("No config file provided and argv is empty; cannot infer a default config")

  if len(argv) >= 2 and argv[1].endswith(".yml"):
    return resolve_config_path(argv[1]), argv[2:]

  module = _module_from_path(argv[0])
  if module is None:
    # Report the offending path instead of the misleading "module 'None'".
    raise ValueError(
        f"No config file provided and no default config can be inferred: '{argv[0]}' "
        "does not resolve to a module under the MaxText package root"
    )
  if module not in _CONFIG_FILE_MAPPING:
    raise ValueError(
        f"No config file provided and no default config found for module '{module}'"
    )

  config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
  logger.warning("No config file provided, using default config mapping: %s", config_path)
  return config_path, argv[1:]
| 91 | + |
49 | 92 |
|
def yaml_key_to_env_key(s: str) -> str:
  """Return the environment-variable name for a YAML key.

  The name is _MAX_PREFIX followed by the upper-cased key.
  """
  upper_key = s.upper()
  return _MAX_PREFIX + upper_key
@@ -227,11 +270,11 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig: |
227 | 270 | Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters` |
228 | 271 | """ |
229 | 272 | # 1. Load base and inherited configs from file(s) |
230 | | - config_path = resolve_config_path(argv[1]) |
| 273 | + config_path, cli_args = _resolve_or_infer_config(argv) |
231 | 274 | base_yml_config = _load_config(config_path) |
232 | 275 |
|
233 | 276 | # 2. Get overrides from CLI and kwargs |
234 | | - cli_cfg = omegaconf.OmegaConf.from_cli(argv[2:]) |
| 277 | + cli_cfg = omegaconf.OmegaConf.from_cli(cli_args) |
235 | 278 | kwargs_cfg = omegaconf.OmegaConf.create(kwargs) |
236 | 279 | overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg) |
237 | 280 |
|
|
0 commit comments