
Commit dc63663

Author: Sharon Yu (committed)
Fix: Update Fine-tuning tutorial and structure
1 parent 9d52020 commit dc63663

1 file changed: 76 additions & 27 deletions

docs/tutorials/posttraining/full_finetuning.md
@@ -31,21 +31,44 @@ The high level steps involve:
 a disk or GCS bucket. Or it can also input data directly from the Hugging Face
 dataset.
 - Running the training script with the checkpoint
-- Note: You may need to change the training parameters to fit the model to the
-TPU or GPU shape and to obtain an optimized performance.
+- Note: Training parameters may require adjustment to align the model with the specific TPU or GPU topology and achieve optimal performance.
 
 ## MaxText checkpoints
 
 MaxText checkpoints are in their own format. You can see the format in the Llama conversion script.
 
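+A MaxText checkpoint is an Orbax directory tree. A quick way to inspect one (a minimal sketch, assuming the checkpoint lives in a GCS bucket you can read; the gs:// path is a placeholder):
+
+```bash
+# List the files of a converted MaxText (Orbax) checkpoint; the 0/items
+# layout matches the load paths used later in this tutorial.
+gcloud storage ls -r gs://my-bucket/my-model-checkpoint/0/items
+```
+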
+### Meta's PyTorch checkpoint to MaxText (Orbax) checkpoint
+
 The conversion scripts for Llama work with Meta's original checkpoints, not with Hugging Face checkpoints.
 
-E.g.
+#### Prerequisites
+- Download the Meta-format checkpoints
+
+Option 1: Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) to a local directory.
+
+Option 2: Download the checkpoint from a GCS bucket to a local directory with the command ```gcloud storage cp -r <GCS path for META format checkpoint> <local/path>```.
+
+- Install the CPU build of PyTorch, since this conversion script does not require a TPU or GPU (a quick install check follows this list).
+
+```python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu```
+
+- Set up environment variables
+
+```bash
+export CONVERTED_CHECKPOINT_PATH=<GCS path for saving converted checkpoint> # e.g., gs://my-bucket/my-model-checkpoint
+export LOCAL_META_CHECKPOINT_PATH=<local path for META checkpoint> # e.g., /local/meta-ckpt
+```
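+
+A quick check that the CPU-only PyTorch install works (a minimal sketch):
+
+```bash
+# Print the torch version; no GPU or TPU is needed for this import to succeed.
+python3 -c "import torch; print(torch.__version__)"
+```
+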
+#### Running the weight conversion script
+
+Using llama2-7b as an example:
 
 ```bash
-python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --base-model-path <path/to/meta/ckpt> \
---maxtext-model-path <GCS/path/to/save/new/maxtext/ckpt> --model-size llama2-7b
+python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt \
+--base-model-path ${LOCAL_META_CHECKPOINT_PATH} \
+--model-size llama2-7b \
+--maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
 ```
+Note:
 
 The conversion scripts do not use accelerators but need large host memory to perform the conversion.
 
@@ -54,48 +77,74 @@ The conversion scripts do not use accelerators but need large host memory to per
 - For large models (e.g., a 70B model), this script requires a large-memory VM.
 - The script loads and saves weights in a single pass.
 
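+Once the conversion finishes, you can sanity-check the output (a minimal sketch, assuming the environment variables set above):
+
+```bash
+# The converted Orbax checkpoint should appear under 0/items, the
+# load path used by the conversion and training steps below.
+gcloud storage ls ${CONVERTED_CHECKPOINT_PATH}/0/items
+```
+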
-### Sample full fine tuning script
+### MaxText checkpoint to Hugging Face
 
-Below is a sample training script for LLama2-7b.
+After fine-tuning or pre-training, MaxText also provides a script to convert MaxText-format weights back to [Hugging Face](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py) format.
 
-```bash
-python3 -m MaxText.train \
-src/MaxText/configs/base.yml \
-run_name="llama2-finetune-maxtext" \
-base_output_directory=${output_directory} \
-load_parameters_path=${path_to_checkpoint} \
-model_name='llama2-7b' \
-dataset_path=${dataset_path} \
-async_checkpointing=False \
-model_name='llama2-7b' \
-steps=10 per_device_batch_size=.25
-```
+#### Sample: converting MaxText-format weights to Hugging Face format
 
-You can find some [end to end scripts here](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu).
-These scripts can provide a reference point for various scripts.
+- Set up environment variables
 
-### MaxText checkpoint to Hugging Face
+```bash
+export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
+export PATH_TO_CHECKPOINT=<GCS path for saving converted checkpoint>/0/items # e.g., ${CONVERTED_CHECKPOINT_PATH}/0/items
+export HF_MODEL_PATH=<local path for the Hugging Face checkpoint> # e.g., /local/convert_ckp
+```
+- Run the conversion script
 
-Post finetuning or pre-training, MaxText also provides scripts to convert MaxText format weights back to [Hugging Face](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py).
+The following example runs on a v6e-8 TPU VM with llama2-7b.
 
-#### Dataset
+```bash
+python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf \
+src/MaxText/configs/base.yml \
+base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+load_parameters_path=${PATH_TO_CHECKPOINT} \
+run_name="mxt-2-hf" \
+model_name='llama2-7b' \
+hardware=tpu \
+hf_model_path=${HF_MODEL_PATH}
+```
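+
+After the script completes, the Hugging Face-format files are written to the local output directory (a minimal check; the exact file names depend on the exporter):
+
+```bash
+# Expect files such as config.json and the model weight files here.
+ls -lh ${HF_MODEL_PATH}
+```
+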
+### Dataset
 
 MaxText provides examples to work with [Common Crawl](https://commoncrawl.org/). The dataset is available in TFRecords format in a cloud bucket. MaxText provides scripts to copy the dataset to a Google Cloud Storage Bucket.
 
 #### Common Crawl (c4) dataset setup
 
-You need to run these steps once per project prior to any local development or cluster experiments.
+Run these steps once per project prior to any local development or cluster experiments.
 
 1. Create two GCS buckets in your project, one for downloading and retrieving the dataset and the other for storing the logs (see the sketch after this list).
 2. Download the dataset into your GCS bucket
 
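+A minimal sketch for step 1 (the bucket names are placeholders):
+
+```bash
+# Create one bucket for the dataset and one for logs, in the same project.
+gcloud storage buckets create gs://my-dataset-bucket --project=<Google Cloud Project ID>
+gcloud storage buckets create gs://my-logs-bucket --project=<Google Cloud Project ID>
+```
+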
 MaxText assumes these GCS buckets are created in the same project and that it has permissions to read and write from them:
 
 ```bash
-bash tools/data_generation/download_dataset.sh ${GCS_PROJECT?} ${GCS_BUCKET_NAME?}
+export PROJECT=<Google Cloud Project ID>
+export DATASET_GCS_BUCKET=<GCS bucket for the dataset> # e.g., gs://my-bucket/my-dataset
+
+bash tools/data_generation/download_dataset.sh ${PROJECT} ${DATASET_GCS_BUCKET}
+```
+
+The above downloads the c4 dataset to the GCS bucket.
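+
+To confirm the download, list the dataset bucket (assuming the variables set above):
+
+```bash
+# The c4 TFRecords should now be present in the bucket.
+gcloud storage ls ${DATASET_GCS_BUCKET}/
+```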
+
+### Sample full fine-tuning script
+
+Below is a sample training script for Llama2-7b on a v6e-8 TPU VM.
+
+```bash
+python3 -m MaxText.train \
+src/MaxText/configs/base.yml \
+run_name="llama2-finetune-maxtext" \
+base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+load_parameters_path=${PATH_TO_CHECKPOINT} \
+model_name='llama2-7b' \
+dataset_path=${DATASET_GCS_BUCKET} \
+async_checkpointing=False \
+steps=10 per_device_batch_size=1
 ```
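+
+While the run executes, logs and checkpoints are written under the output directory (a sketch; the run_name subdirectory is an assumption based on MaxText's usual output layout):
+
+```bash
+# Inspect run artifacts; the run_name matches the training command above.
+gcloud storage ls ${BASE_OUTPUT_DIRECTORY}/llama2-finetune-maxtext/
+```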
 
-The above will download the c4 dataset to your GCS BUCKET.
+You can find some [end to end scripts here](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu).
+These can serve as a reference point for various workflows.
 
 ## Parameters to achieve high MFU
 