File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -59,6 +59,12 @@ def resolve_config_path(param: str) -> str:
5959 lowercase_param = param .replace ("MaxText" , "maxtext" )
6060 if os .path .isfile (lowercase_param ):
6161 return lowercase_param
62+ # For pip-installed packages, strip the src prefix and resolve against
63+ # the installed configs directory (MAXTEXT_CONFIGS_DIR).
64+ if param .startswith ("src/maxtext/configs/" ):
65+ candidate = os .path .join (MAXTEXT_CONFIGS_DIR , param [len ("src/maxtext/configs/" ):])
66+ if os .path .isfile (candidate ):
67+ return candidate
6268 return os .path .join ("src" , param )
6369
6470
Original file line number Diff line number Diff line change 2323Usage Examples:
2424
2525# GRPO on Llama3.1-8B-Instruct
26- python3 -m src. maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
26+ python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
2727 model_name=llama3.1-8b \
2828 tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
2929 load_parameters_path=gs://path/to/checkpoint/0/items \
3232 hf_access_token=${HF_TOKEN?}
3333
3434# GSPO on Llama3.1-70B-Instruct
35- python3 -m src. maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
35+ python3 -m maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
3636 model_name=llama3.1-70b \
3737 tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
3838 load_parameters_path=gs://path/to/checkpoint/0/items \
Original file line number Diff line number Diff line change 1414
1515"""Tests for pyconfig."""
1616
17- import unittest
1817import os .path
18+ import tempfile
19+ import unittest
1920
2021from maxtext .configs import pyconfig
2122from maxtext .configs .pyconfig import resolve_config_path
22- from maxtext .utils .globals import MAXTEXT_PKG_DIR
23+ from maxtext .utils .globals import MAXTEXT_CONFIGS_DIR , MAXTEXT_PKG_DIR
2324from tests .utils .test_helpers import get_test_config_path , get_post_train_test_config_path
2425
2526
@@ -101,6 +102,19 @@ def test_resolve_config_path(self):
101102 self .assertEqual (resolve_config_path ("foo" ), os .path .join ("src" , "foo" ))
102103 self .assertEqual (resolve_config_path (__file__ ), __file__ )
103104
105+ def test_resolve_config_path_pip_install (self ):
106+ """Simulates pip-installed env where cwd has no src/ folder."""
107+ orig = os .getcwd ()
108+ with tempfile .TemporaryDirectory () as tmpdir :
109+ try :
110+ os .chdir (tmpdir )
111+ result = resolve_config_path ("src/maxtext/configs/base.yml" )
112+ self .assertEqual (result , os .path .join (MAXTEXT_CONFIGS_DIR , "base.yml" ))
113+ result = resolve_config_path ("src/maxtext/configs/post_train/rl.yml" )
114+ self .assertEqual (result , os .path .join (MAXTEXT_CONFIGS_DIR , "post_train/rl.yml" ))
115+ finally :
116+ os .chdir (orig )
117+
104118
105119if __name__ == "__main__" :
106120 unittest .main ()
You can’t perform that action at this time.
0 commit comments