|
35 | 35 | import omegaconf |
36 | 36 |
|
37 | 37 | import benchmarks.maxtext_trillium_model_configs as model_configs |
38 | | -from benchmarks.globals import MAXTEXT_CONFIGS_DIR |
| 38 | +from benchmarks.globals import MAXTEXT_PKG_DIR |
39 | 39 | from benchmarks.command_utils import run_command_with_updates |
40 | 40 | import benchmarks.xla_flags_library as xla_flags |
41 | 41 | from benchmarks.disruption_management.disruption_handler import DisruptionConfig |
@@ -107,7 +107,7 @@ class WorkloadConfig: |
107 | 107 | generate_metrics_and_upload_to_big_query: bool = True |
108 | 108 | hardware_id: str = "v6e" |
109 | 109 | metrics_gcs_file: str = "" |
110 | | - base_config: str = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml") |
| 110 | + base_config: str = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml") |
111 | 111 | topology: str = dataclasses.field(init=False) |
112 | 112 | num_devices_per_slice: int = dataclasses.field(init=False) |
113 | 113 | db_project: str = "" |
@@ -354,7 +354,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict: |
354 | 354 | "xla_flags": f"'{xla_flags_str}'", |
355 | 355 | "dataset": dataset, |
356 | 356 | "run_type": "maxtext-xpk", |
357 | | - "config_file": os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml"), |
| 357 | + "config_file": os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), |
358 | 358 | "topology": wl_config.topology, |
359 | 359 | "tuning_params": f"'{tuning_params_str}'", |
360 | 360 | "db_project": wl_config.db_project, |
@@ -439,8 +439,8 @@ def build_user_command( |
439 | 439 | "export ENABLE_PATHWAYS_PERSISTENCE=1 &&", |
440 | 440 | f"export JAX_PLATFORMS={jax_platforms} &&", |
441 | 441 | "export ENABLE_PJRT_COMPATIBILITY=true &&", |
442 | | - "export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/MaxText MAXTEXT_REPO_ROOT=/deps &&" |
443 | | - f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")}', |
| 442 | + "export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/maxtext MAXTEXT_REPO_ROOT=/deps &&" |
| 443 | + f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")}', |
444 | 444 | f"{config_tuning_params}", |
445 | 445 | f"steps={wl_config.num_steps}", |
446 | 446 | f"model_name={wl_config.model.model_type}", |
|
0 commit comments