Skip to content

Commit 1238d84

Browse files
Merge pull request #3126 from AI-Hypercomputer:anisha-fix-pyconfig
PiperOrigin-RevId: 869404856
2 parents fa52b6a + f683be3 commit 1238d84

3 files changed

Lines changed: 27 additions & 1 deletion

File tree

src/MaxText/pyconfig.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
233233
if model_name != "default":
234234
# First try relative to base config path
235235
model_config_path = os.path.join(os.path.dirname(config_path), "models", f"{model_name}.yml")
236+
# Try looking for "models" under "src/maxtext/configs/"
237+
if not os.path.isfile(model_config_path):
238+
model_config_path = os.path.join(os.path.dirname(os.path.dirname(config_path)), "models", f"{model_name}.yml")
239+
236240
if not os.path.isfile(model_config_path):
237241
# Fallback to default location within package
238242
dir_path = os.path.dirname(os.path.realpath(__file__))

tests/unit/pyconfig_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from MaxText import pyconfig
2121
from MaxText.pyconfig import resolve_config_path
2222
from MaxText.globals import MAXTEXT_PKG_DIR
23-
from tests.utils.test_helpers import get_test_config_path
23+
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path
2424

2525

2626
class PyconfigTest(unittest.TestCase):
@@ -85,6 +85,18 @@ def test_overriding_model(self):
8585
self.assertEqual(config.base_emb_dim, 1024)
8686
self.assertEqual(config.base_mlp_dim, 24576)
8787

88+
def test_overriding_model_in_sft(self):
89+
# TODO: Update MAXTEXT_PKG_DIR after repo restructuring is complete.
90+
config = pyconfig.initialize(
91+
[os.path.join("maxtext.trainers.post_train.sft.train_sft"), get_post_train_test_config_path("sft")],
92+
skip_jax_distributed_system=True,
93+
model_name="llama3.1-8b",
94+
override_model_config=True,
95+
)
96+
97+
self.assertEqual(config.base_emb_dim, 4096)
98+
self.assertEqual(config.base_mlp_dim, 14336)
99+
88100
def test_resolve_config_path(self):
89101
self.assertEqual(resolve_config_path("foo"), os.path.join("src", "foo"))
90102
self.assertEqual(resolve_config_path(__file__), __file__)

tests/utils/test_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ def get_test_config_path():
3535
return os.path.join(MAXTEXT_CONFIGS_DIR, base_cfg)
3636

3737

38+
def get_post_train_test_config_path(sub_type="sft"):
39+
"""Return absolute path to the chosen test config file.
40+
41+
Returns `decoupled_base_test.yml` when decoupled, otherwise `base.yml`.
42+
"""
43+
base_cfg = "rl.yml" if sub_type == "rl" else "sft.yml"
44+
return os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", base_cfg)
45+
46+
3847
def get_test_dataset_path(cloud_path=None):
3948
"""Return the dataset path for tests.
4049
@@ -70,5 +79,6 @@ def get_test_base_output_directory(cloud_path=None):
7079
__all__ = [
7180
"get_test_base_output_directory",
7281
"get_test_config_path",
82+
"get_post_train_test_config_path",
7383
"get_test_dataset_path",
7484
]

0 commit comments

Comments
 (0)