Skip to content

Commit 94f760d

Browse files
bvandermoon authored and Google-ML-Automation committed
Fix config path in maxtext_xpk_runner.py
PiperOrigin-RevId: 878671596
1 parent d9d167e commit 94f760d

10 files changed

Lines changed: 18 additions & 18 deletions

File tree

benchmarks/globals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os.path
1818

1919
# This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
20-
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")
20+
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/maxtext")
2121

2222
# This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
2323
MAXTEXT_REPO_ROOT = os.environ.get(

benchmarks/maxtext_xpk_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import omegaconf
3636

3737
import benchmarks.maxtext_trillium_model_configs as model_configs
38-
from benchmarks.globals import MAXTEXT_CONFIGS_DIR
38+
from benchmarks.globals import MAXTEXT_PKG_DIR
3939
from benchmarks.command_utils import run_command_with_updates
4040
import benchmarks.xla_flags_library as xla_flags
4141
from benchmarks.disruption_management.disruption_handler import DisruptionConfig
@@ -107,7 +107,7 @@ class WorkloadConfig:
107107
generate_metrics_and_upload_to_big_query: bool = True
108108
hardware_id: str = "v6e"
109109
metrics_gcs_file: str = ""
110-
base_config: str = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
110+
base_config: str = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
111111
topology: str = dataclasses.field(init=False)
112112
num_devices_per_slice: int = dataclasses.field(init=False)
113113
db_project: str = ""
@@ -354,7 +354,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
354354
"xla_flags": f"'{xla_flags_str}'",
355355
"dataset": dataset,
356356
"run_type": "maxtext-xpk",
357-
"config_file": os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml"),
357+
"config_file": os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
358358
"topology": wl_config.topology,
359359
"tuning_params": f"'{tuning_params_str}'",
360360
"db_project": wl_config.db_project,
@@ -439,8 +439,8 @@ def build_user_command(
439439
"export ENABLE_PATHWAYS_PERSISTENCE=1 &&",
440440
f"export JAX_PLATFORMS={jax_platforms} &&",
441441
"export ENABLE_PJRT_COMPATIBILITY=true &&",
442-
"export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/MaxText MAXTEXT_REPO_ROOT=/deps &&"
443-
f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")}',
442+
"export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/maxtext MAXTEXT_REPO_ROOT=/deps &&"
443+
f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")}',
444444
f"{config_tuning_params}",
445445
f"steps={wl_config.num_steps}",
446446
f"model_name={wl_config.model.model_type}",

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext maxtext/configs/base.yml \
8282

8383
**Key arguments:**
8484

85-
- `model_name`: The model identifier, which should be defined in `src/MaxText/utils/utils.py`.
85+
- `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
8686
- `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false).
8787
- `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
8888
- `hf_access_token`: Your Hugging Face token.

docs/guides/optimization/pallas_kernels_performance.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth
5858

5959
- **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation.
6060

61-
- [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py)
61+
- [`src/MaxText/kernels/attention/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/attention/splash_attention_kernel.py)
6262

6363
- **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.
6464

src/maxtext/examples/sft_llama3_demo.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@
362362
"\n",
363363
"- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n",
364364
"- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n",
365-
"- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation"
365+
"- **Documentation**: Check `src/maxtext/trainers/post_train/sft/train_sft.py` for the `train` function implementation"
366366
]
367367
}
368368
],

src/maxtext/utils/globals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import os.path
1818

19-
# This is the maxtext package root (src/MaxText)
19+
# This is the maxtext package root (src/maxtext)
2020
# Since this file is at src/maxtext/utils/globals.py, we need to go up 2 levels
2121
MAXTEXT_PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
2222

tests/end_to_end/tpu/deepseek/Run_DeepSeek.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
5656

5757

5858
## Checkpoint conversion
59-
To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently for V3, V3.1, and R1, it uses mixed precision fp8 & bf16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:
60-
* run [convert_deepseek_family_ckpt.py](../../../src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
61-
* run [convert_deepseek_family_unscanned_ckpt.py](../../../src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding.
59+
To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently for V3, V3.1, and R1, it uses mixed precision fp8 & bf16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:
60+
* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
61+
* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding.
6262

6363

6464
## Fine-tuning

tests/end_to_end/tpu/gemma/Run_Gemma.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials).
2121

22-
After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [tests/end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu/gemma).
22+
After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [tests/end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu/gemma).
2323

2424
## MaxText supports pretraining and finetuning with high performance
2525

tests/end_to_end/tpu/gemma3/Run_Gemma3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml model_n
2929
```
3030

3131
## Checkpoint Conversion
32-
To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket.
32+
To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket.
3333

3434
## Fine-tuning
3535
After the conversion, you will have a MaxText compatible checkpoint which allows you to fine-tune it with different datasets. One example command to fine-tune a Gemma3-4B model is as follows:

tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,22 @@ hf download [openai/gpt-oss-20b|openai/gpt-oss-120b] --local-dir <local_mxfp4_pa
3131
```
3232

3333

34-
2. Please convert it from MXFP4 to BF16 using script [dequantize_mxfp4.py](../../../src/MaxText/utils/ckpt_scripts/dequantize_mxfp4.py) on gpu.
34+
2. Please convert it from MXFP4 to BF16 using script [dequantize_mxfp4.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/dequantize_mxfp4.py) on gpu.
3535

3636
```
3737
python3 -m maxtext.checkpoint_conversion.standalone_scripts.dequantize_mxfp4 --input-path=<local_mxfp4_path> --output-path=<local_bf16_path> --dtype-str=bf16
3838
```
3939

4040

4141
3. Once downloaded and converted to BF16:
42-
* run [convert_gpt_oss_ckpt.py](../../../src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning.
42+
* run [convert_gpt_oss_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning.
4343

4444
```
4545
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_ckpt --base-model-path <local_bf16_path> \
4646
--maxtext-model-path <GCS/path/to/scanned/maxtext/ckpt> --model-size [gpt-oss-20b|gpt-oss-120b]
4747
```
4848

49-
* run [convert_gpt_oss_unscanned_ckpt.py](../../../src/MaxText/utils/ckpt_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding.
49+
* run [convert_gpt_oss_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding.
5050

5151
```
5252
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_unscanned_ckpt --base-model-path <local_bf16_path> \

0 commit comments

Comments (0)