Skip to content

Commit d2c172a

Browse files
Merge pull request #3356 from AI-Hypercomputer:fix-pypi-config-resolve
PiperOrigin-RevId: 881115051
2 parents b125113 + 5333a5e commit d2c172a

3 files changed

Lines changed: 24 additions & 4 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def resolve_config_path(param: str) -> str:
5959
lowercase_param = param.replace("MaxText", "maxtext")
6060
if os.path.isfile(lowercase_param):
6161
return lowercase_param
62+
# For pip-installed packages, strip the src prefix and resolve against
63+
# the installed configs directory (MAXTEXT_CONFIGS_DIR).
64+
if param.startswith("src/maxtext/configs/"):
65+
candidate = os.path.join(MAXTEXT_CONFIGS_DIR, param[len("src/maxtext/configs/"):])
66+
if os.path.isfile(candidate):
67+
return candidate
6268
return os.path.join("src", param)
6369

6470

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Usage Examples:
2424
2525
# GRPO on Llama3.1-8B-Instruct
26-
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
26+
python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
2727
model_name=llama3.1-8b \
2828
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
2929
load_parameters_path=gs://path/to/checkpoint/0/items \
@@ -32,7 +32,7 @@
3232
hf_access_token=${HF_TOKEN?}
3333
3434
# GSPO on Llama3.1-70B-Instruct
35-
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
35+
python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
3636
model_name=llama3.1-70b \
3737
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
3838
load_parameters_path=gs://path/to/checkpoint/0/items \

tests/unit/pyconfig_test.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
"""Tests for pyconfig."""
1616

17-
import unittest
1817
import os.path
18+
import tempfile
19+
import unittest
1920

2021
from maxtext.configs import pyconfig
2122
from maxtext.configs.pyconfig import resolve_config_path
22-
from maxtext.utils.globals import MAXTEXT_PKG_DIR
23+
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
2324
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path
2425

2526

@@ -101,6 +102,19 @@ def test_resolve_config_path(self):
101102
self.assertEqual(resolve_config_path("foo"), os.path.join("src", "foo"))
102103
self.assertEqual(resolve_config_path(__file__), __file__)
103104

105+
def test_resolve_config_path_pip_install(self):
  """Check config resolution from a working directory that has no src/ tree.

  Mimics a pip-installed environment by switching into an empty temporary
  directory: repo-style "src/maxtext/configs/..." arguments are expected to
  resolve to the matching file under MAXTEXT_CONFIGS_DIR.
  """
  # (argument passed to resolve_config_path, path relative to MAXTEXT_CONFIGS_DIR)
  cases = (
      ("src/maxtext/configs/base.yml", "base.yml"),
      ("src/maxtext/configs/post_train/rl.yml", "post_train/rl.yml"),
  )
  starting_dir = os.getcwd()
  with tempfile.TemporaryDirectory() as scratch_dir:
    try:
      os.chdir(scratch_dir)
      for config_arg, relative_path in cases:
        resolved = resolve_config_path(config_arg)
        self.assertEqual(resolved, os.path.join(MAXTEXT_CONFIGS_DIR, relative_path))
    finally:
      # Restore the working directory before the temporary directory is removed.
      os.chdir(starting_dir)
117+
104118

105119
if __name__ == "__main__":
106120
unittest.main()

0 commit comments

Comments
 (0)