Commit 50f0082

Author: Sharon Yu
Commit message: resolve comments
1 parent: ddd85c6

1 file changed (66 additions, 77 deletions): docs/tutorials/posttraining/full_finetuning.md
limitations under the License.
-->

# Full fine-tuning on single-host TPUs

Full Fine-Tuning (FFT) is a common technique used in post-training to adapt a pre-trained Large Language Model (LLM) to a specific downstream task or dataset. In this process, all the parameters (weights) of the original model are "unfrozen" and updated during training on the new task-specific data. This allows the entire model to adjust and specialize, potentially leading to the best performance on the new task.

This tutorial provides step-by-step instructions for setting up the environment, converting a checkpoint, and then training the model on a Hugging Face dataset using FFT.

In this tutorial we use a single-host TPU VM such as `v6e-8/v5p-8`. Let's get started!
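
If you still need to provision the TPU VM, a single-host slice can be created with the `gcloud` CLI. The sketch below is illustrative and not part of the original tutorial: the VM name, zone, and runtime `--version` are assumptions, so confirm the values available to your project against the Cloud TPU documentation.

```sh
# Illustrative sketch: provision a single-host v6e-8 TPU VM.
# The name, zone, and --version below are placeholders; adjust for your project.
gcloud compute tpus tpu-vm create my-tpu-vm \
  --zone=us-east5-b \
  --accelerator-type=v6e-8 \
  --version=v2-alpha-tpuv6e

# SSH into the VM to run the remaining steps.
gcloud compute tpus tpu-vm ssh my-tpu-vm --zone=us-east5-b
```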

## Install dependencies

```sh
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

# 2. Create virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed $VENV_NAME
source $VENV_NAME/bin/activate

# 3. Install dependencies in editable mode
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_github_deps
```
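
Before continuing, it is worth confirming that JAX can see the host's TPU devices; this quick check is an addition to the tutorial, not part of it.

```sh
# Optional sanity check: JAX should list the TPU devices on this host.
python3 -c "import jax; print(jax.devices())"
```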

## Set up environment variables

```sh
# -- Model configuration --
export MODEL_NAME=<model name> # e.g., 'llama2-7b'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
```

## Hugging Face checkpoint to MaxText checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export MODEL_CKPT_PATH=<GCS path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```
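
Optionally, verify that the path actually points at a readable checkpoint before training:

```sh
# Optional: confirm the checkpoint exists and is readable.
gcloud storage ls ${MODEL_CKPT_PATH}
```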

### Option 2: Converting a Hugging Face checkpoint

If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:

   ```sh
   export MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
   ```

2. **Run the Conversion Script:** Run the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face; to see the specific models and versions currently supported, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).

   ```sh
   # Ensure torch is installed for the conversion script
   python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

   python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
       model_name=${MODEL_NAME} \
       hf_access_token=${HF_TOKEN} \
       base_output_directory=${MODEL_CKPT_DIRECTORY} \
       scan_layers=True skip_jax_distributed_system=True
   ```
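
After the conversion finishes, point `MODEL_CKPT_PATH` at the converted weights so the training step below can load them. The `/0/items` suffix here is an assumption based on the example path in Option 1; list the output directory to confirm the actual layout.

```sh
# List the converter output, then set MODEL_CKPT_PATH accordingly.
# The /0/items suffix is assumed from the Option 1 example path; verify it.
gcloud storage ls ${MODEL_CKPT_DIRECTORY}
export MODEL_CKPT_PATH=${MODEL_CKPT_DIRECTORY}/0/items
```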

## MaxText checkpoint to Hugging Face checkpoint

Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem.

```sh
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
    model_name=${MODEL_NAME} \
    load_parameters_path=${MODEL_CKPT_PATH} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    scan_layers=false \
    use_multimodal=false \
    hf_access_token=${HF_TOKEN} \
    weight_dtype=bfloat16
```
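
As a quick round-trip check, you can copy the exported files locally and load them with `transformers`. This is a sketch rather than part of the tutorial: the export subdirectory is a placeholder, so check what the script actually wrote under `BASE_OUTPUT_DIRECTORY`.

```sh
# Optional round-trip check; <export subdir> is a placeholder for the
# directory the conversion script created under BASE_OUTPUT_DIRECTORY.
gcloud storage cp -r ${BASE_OUTPUT_DIRECTORY}/<export subdir> /tmp/hf-export
python3 - <<'EOF'
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the converted checkpoint to confirm the files are well-formed.
model = AutoModelForCausalLM.from_pretrained("/tmp/hf-export")
tokenizer = AutoTokenizer.from_pretrained("/tmp/hf-export")
print(model.config)
EOF
```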

## Dataset

MaxText provides examples that work with [Common Crawl](https://commoncrawl.org/). The dataset is available in TFRecord format in a cloud bucket, and MaxText includes scripts to copy it to a Google Cloud Storage bucket.

### Common Crawl (c4) dataset setup

Run these steps once per project prior to any local development or cluster experiments.

1. Create two GCS buckets in your project, one for downloading and retrieving the dataset and the other for storing the logs.
2. Download the dataset into your GCS bucket using the commands below.

MaxText assumes these GCS buckets are created in the same project and that it has permissions to read and write from them.

```sh
export PROJECT=<Google Cloud Project ID>
export DATASET_GCS_BUCKET=<GCS for dataset> # e.g., gs://my-bucket/my-dataset

bash tools/data_generation/download_dataset.sh ${PROJECT} ${DATASET_GCS_BUCKET}
```

The above will download the c4 dataset to the GCS bucket.
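
You can confirm the copy succeeded by listing the bucket:

```sh
# Confirm the c4 TFRecords landed in the dataset bucket.
gcloud storage ls ${DATASET_GCS_BUCKET}
```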

## Sample full fine-tuning script

Below is a sample training script for Llama2-7b on a v6e-8 TPU VM.

```sh
python3 -m MaxText.train \
    src/MaxText/configs/base.yml \
    run_name="llama2-finetune-maxtext" \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    load_parameters_path=${MODEL_CKPT_PATH} \
    model_name='llama2-7b' \
    dataset_path=${DATASET_GCS_BUCKET} \
    async_checkpointing=False
```
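
Once training is running, you can follow the loss curves with TensorBoard. The log location below is an assumption; MaxText typically writes TensorBoard summaries under the run's output directory, so verify the exact path in your run logs.

```sh
# Assumed log layout; verify the actual tensorboard path in your output dir.
tensorboard --logdir=${BASE_OUTPUT_DIRECTORY}/llama2-finetune-maxtext/tensorboard
```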
