Commit 39462d5

Merge pull request #2880 from AI-Hypercomputer:xibin/doc
PiperOrigin-RevId: 850325633
2 parents: fbcee8f + 71f660f

2 files changed: 24 additions & 6 deletions

docs/tutorials/posttraining/sft.md

Lines changed: 9 additions & 3 deletions
@@ -75,10 +75,10 @@ export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs:
 ### Option 2: Converting a Hugging Face checkpoint
 If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
 
-1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:
 
 ```sh
-export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items
+export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
 ```
 
 2. **Run the Conversion Script:** Execute the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face; to see the specific models and versions currently supported for conversion, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
@@ -89,10 +89,16 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu #
 python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
 model_name=${PRE_TRAINED_MODEL} \
 hf_access_token=${HF_TOKEN} \
-base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \
+base_output_directory=${PRE_TRAINED_MODEL_CKPT_DIRECTORY} \
 scan_layers=True skip_jax_distributed_system=True
 ```
 
+3. **Use the Converted Checkpoint:** Set the following environment variable to use the converted checkpoint:
+
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items
+```
+
 ## Run SFT on Hugging Face Dataset
 Now you are ready to run SFT using the following command:
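Taken together, the updated steps wire the paths as follows. This is a minimal sketch of the path composition only (the bucket name `gs://my-bucket/sft-run` is a hypothetical placeholder, and no conversion is actually run here):

```shell
# Hypothetical bucket path for illustration only.
export BASE_OUTPUT_DIRECTORY=gs://my-bucket/sft-run

# Step 1: directory where the conversion script writes the MaxText checkpoint.
export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint

# Step 2 (to_maxtext) saves the checkpoint under ${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items.
# Step 3: point the SFT run at that converted checkpoint.
export PRE_TRAINED_MODEL_CKPT_PATH=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items

echo "${PRE_TRAINED_MODEL_CKPT_PATH}"
# prints: gs://my-bucket/sft-run/maxtext-checkpoint/0/items
```

Splitting the directory from the final `/0/items` path is what lets step 2 and step 3 share one variable instead of repeating the bucket prefix.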

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 15 additions & 3 deletions
@@ -107,10 +107,10 @@ export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-b
 ### Option 2: Converting a Hugging Face checkpoint
 If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
 
-1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:
 
 ```bash
-export MODEL_CHECKPOINT_PATH=${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/0/items
+export MODEL_CHECKPOINT_DIRECTORY=${OUTPUT_PATH}/maxtext-checkpoint
 ```
 
 2. **Run the Conversion Script:** Execute the following commands on a CPU machine to download the specified Hugging Face model, convert its weights into the MaxText format, and save the result to the specified GCS bucket. The conversion script only supports official versions of models from Hugging Face; to see the specific models and versions currently supported for conversion, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
@@ -122,7 +122,19 @@ USE_OCDBT=<Flag to use ocdbt> # True to run SFT with McJAX, False to run SFT wit
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
 # For large models, it is recommended to set the `--lazy_load_tensors` flag to reduce memory usage during conversion
-python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True checkpoint_storage_use_zarr3=$USE_ZARR3 checkpoint_storage_use_ocdbt=$USE_OCDBT skip_jax_distributed_system=True --lazy_load_tensors=True
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+  model_name=$MODEL_NAME \
+  hf_access_token=$HF_TOKEN \
+  base_output_directory=$MODEL_CHECKPOINT_DIRECTORY \
+  scan_layers=True \
+  checkpoint_storage_use_zarr3=$USE_ZARR3 checkpoint_storage_use_ocdbt=$USE_OCDBT \
+  skip_jax_distributed_system=True --lazy_load_tensors=True
+```
+
+3. **Use the Converted Checkpoint:** Set the following environment variable to use the converted checkpoint:
+
+```bash
+export MODEL_CHECKPOINT_PATH=${MODEL_CHECKPOINT_DIRECTORY}/0/items
 ```
 
 ## 6. Submit workload on GKE cluster
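Before submitting the workload, it can help to confirm the converted checkpoint actually exists at the path step 3 derives. A minimal sketch, assuming a hypothetical bucket name; the `gsutil` check is optional and assumes the Cloud SDK is installed:

```shell
# Hypothetical bucket path for illustration only.
export OUTPUT_PATH=gs://my-bucket/multi-host-sft

# Step 1: directory the conversion script writes to.
export MODEL_CHECKPOINT_DIRECTORY=${OUTPUT_PATH}/maxtext-checkpoint

# Step 3: the converted checkpoint lands under ${MODEL_CHECKPOINT_DIRECTORY}/0/items.
export MODEL_CHECKPOINT_PATH=${MODEL_CHECKPOINT_DIRECTORY}/0/items

echo "${MODEL_CHECKPOINT_PATH}"
# prints: gs://my-bucket/multi-host-sft/maxtext-checkpoint/0/items

# Optional sanity check that the conversion produced files (requires gsutil):
# gsutil ls "${MODEL_CHECKPOINT_PATH}"
```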
