|
20 | 20 | from MaxText import pyconfig |
21 | 21 | from MaxText.pyconfig import resolve_config_path |
22 | 22 | from MaxText.globals import MAXTEXT_PKG_DIR |
23 | | -from tests.utils.test_helpers import get_test_config_path |
| 23 | +from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path |
24 | 24 |
|
25 | 25 |
|
26 | 26 | class PyconfigTest(unittest.TestCase): |
@@ -85,6 +85,18 @@ def test_overriding_model(self): |
85 | 85 | self.assertEqual(config.base_emb_dim, 1024) |
86 | 86 | self.assertEqual(config.base_mlp_dim, 24576) |
87 | 87 |
|
| 88 | + def test_overriding_model_in_sft(self): |
| 89 | + # TODO: Update MAXTEXT_PKG_DIR after repo restructuring is complete. |
| 90 | + config = pyconfig.initialize( |
| 91 | + [os.path.join("maxtext.trainers.post_train.sft.train_sft"), get_post_train_test_config_path("sft")], |
| 92 | + skip_jax_distributed_system=True, |
| 93 | + model_name="llama3.1-8b", |
| 94 | + override_model_config=True, |
| 95 | + ) |
| 96 | + |
| 97 | + self.assertEqual(config.base_emb_dim, 4096) |
| 98 | + self.assertEqual(config.base_mlp_dim, 14336) |
| 99 | + |
88 | 100 | def test_resolve_config_path(self): |
89 | 101 | self.assertEqual(resolve_config_path("foo"), os.path.join("src", "foo")) |
90 | 102 | self.assertEqual(resolve_config_path(__file__), __file__) |
|
0 commit comments