
Commit 31d0b8c

Merge pull request #2960 from AI-Hypercomputer:jackyf/docs/distillation
PiperOrigin-RevId: 863276718
2 parents af14e43 + 4b59129 commit 31d0b8c

2 files changed: 325 additions & 97 deletions

docs/tutorials/posttraining/knowledge_distillation.md: 128 additions & 97 deletions
# Knowledge distillation
## Overview

Knowledge Distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the smaller model to achieve performance levels closer to the larger one, but with significantly fewer parameters and computational resources.

This tutorial focuses on **response-based knowledge distillation**, a technique where the student model is trained to replicate the outputs and behaviors of the teacher model. Within response-based knowledge distillation, two primary methods are often employed:

1. **Offline Distillation (Dataset Generation):**

   - The pre-trained teacher model (running in vLLM) generates a new dataset of input-output pairs.
   - The student model is then trained on this teacher-generated dataset using standard fine-tuning techniques in MaxText.

1. **Online Distillation (Logit Matching):**

   - During the training process, both the teacher model (which is typically frozen) and the student model process the same input data simultaneously.
   - The student model is trained by minimizing a loss function that encourages its output logits to match the logits produced by the teacher model for the same inputs (a minimal sketch of such a loss follows below).
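For intuition, here is a minimal, self-contained sketch of such a logit-matching loss in JAX. It is illustrative only (not MaxText's implementation), and the array shapes and temperature are assumptions:

```python
# Illustrative sketch of online distillation's logit-matching loss (not MaxText code).
# Assumes logits of shape [batch, seq_len, vocab] from a frozen teacher and a student.
import jax
import jax.numpy as jnp


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, averaged over tokens."""
    t_log_probs = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
    s_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    # KL divergence per token, summed over the vocabulary.
    kl = jnp.sum(jnp.exp(t_log_probs) * (t_log_probs - s_log_probs), axis=-1)
    # Scaling by temperature**2 keeps gradient magnitudes comparable across temperatures.
    return (temperature ** 2) * jnp.mean(kl)


# Example with random logits: batch=2, seq_len=4, vocab=8.
teacher = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
student = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 8))
print(distillation_loss(student, teacher))
```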
## Running Offline Distillation with MaxText

The following recipe demonstrates the process of offline distillation using **Qwen/Qwen3-32B** as the teacher model and **Llama-3.1-8B** as the student model. Since this recipe fine-tunes the student model using Supervised Fine-Tuning (SFT), it's crucial to use the conversational variant for both the teacher and student models. Here's a step-by-step tutorial:

### Prerequisites

#### a. Setup environment variables

```bash
export HF_TOKEN=<your-hf-token>    # e.g., hf_BA6...
export RUN_NAME=<your-run-name>    # e.g., distill-20260115
```
#### b. Install dependencies

To install MaxText and its dependencies for post-training (including vLLM for the teacher), run the following:

1. Follow the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#install-maxtext).

1. Install the additional dependencies for post-training:

```bash
bash tools/setup/setup_post_training_requirements.sh
```

#### c. Setup storage with Hyperdisk

To store large models and datasets, attach a Hyperdisk to your TPU VM. Refer to the [Google Cloud Hyperdisk documentation](https://cloud.google.com/compute/docs/disks/add-hyperdisk) and [TPU VM management](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm) for detailed instructions.

First, create a Hyperdisk:

```bash
export ZONE=<your-tpu-zone>            # e.g., us-central1-a
export TPU_VM_NAME=<your-tpu-vm-name>
export DISK_NAME=<your-disk-name>      # e.g., my-hyperdisk
export DISK_SIZE=<disk-size>           # e.g., 500GB

gcloud compute disks create ${DISK_NAME} \
  --size=${DISK_SIZE} \
  --type=hyperdisk-balanced \
  --zone=${ZONE}
```

Then, attach the disk to your TPU VM:

```bash
gcloud compute instances attach-disk ${TPU_VM_NAME} \
  --disk=${DISK_NAME} \
  --zone=${ZONE}
```

Inside the TPU VM, format and mount the disk (if not already mounted):

```bash
# Assuming the disk is /dev/sdb, check with lsblk
sudo mkfs.ext4 /dev/sdb
sudo mkdir -p /mnt/hyperdisk
sudo mount /dev/sdb /mnt/hyperdisk
```
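Before downloading multi-gigabyte checkpoints, it can help to confirm the mount has enough headroom (the Qwen3-32B weights alone are roughly 60-70 GB in bf16). A small Python check, assuming the `/mnt/hyperdisk` mount point used above and a rough space budget of your choosing:

```python
# Quick sanity check of free space on the mounted Hyperdisk.
import shutil

MOUNT_POINT = "/mnt/hyperdisk"  # mount point assumed from the step above
# Rough working estimate for this tutorial: teacher + student weights,
# the converted checkpoint, and the generated dataset (an assumption, not a hard requirement).
REQUIRED_GB = 150

usage = shutil.disk_usage(MOUNT_POINT)
free_gb = usage.free / 1024**3
print(f"{MOUNT_POINT}: {free_gb:.0f} GB free of {usage.total / 1024**3:.0f} GB")
if free_gb < REQUIRED_GB:
    print(f"Warning: less than {REQUIRED_GB} GB free; consider a larger disk.")
```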
Update the BASE_DIRECTORY to point to the mounted disk and create the directory:

```bash
export BASE_NAME=<your-base-directory>    # e.g., knowledge-distillation
export BASE_DIRECTORY=/mnt/hyperdisk/${BASE_NAME}
mkdir -p ${BASE_DIRECTORY}
```

> **Note:** This tutorial uses a mounted Hyperdisk for performance and reproducibility, because writing large model files and many small I/O operations directly to `gs://` can be significantly slower.

### Obtain and prepare the teacher model

For the teacher model, we will use **vLLM** to run inference. vLLM can load Hugging Face checkpoints directly, so **no conversion to MaxText format is needed** for the teacher. Ensure the teacher model is supported on TPU vLLM (refer to the [vLLM TPU recommended models](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/#text-only-models) for the latest list).

You can simply download the model from Hugging Face to your local directory:

```bash
huggingface-cli login --token $HF_TOKEN
huggingface-cli download Qwen/Qwen3-32B --repo-type model --local-dir ${BASE_DIRECTORY}/qwen3-32b
```
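As a quick smoke test that vLLM can serve the downloaded teacher, you can load it and sample a couple of completions directly from Python. This is an illustrative sketch, not part of the recipe; it assumes a recent vLLM release (with `LLM` and `SamplingParams`), an 8-chip slice for tensor parallelism, and a throwaway prompt:

```python
# Minimal vLLM smoke test for the teacher model (illustrative only).
import os
from vllm import LLM, SamplingParams

teacher_path = os.path.join(os.environ["BASE_DIRECTORY"], "qwen3-32b")

# tensor_parallel_size=8 assumes an 8-chip TPU slice (e.g., v6e-8) for the 32B teacher.
llm = LLM(model=teacher_path, tensor_parallel_size=8)

# n=2 mirrors the --num-generations 2 setting used later for dataset generation.
params = SamplingParams(temperature=0.7, max_tokens=128, n=2)

outputs = llm.generate(["Explain knowledge distillation in one sentence."], params)
for completion in outputs[0].outputs:
    print(completion.text)
```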
### Obtain and prepare the student model

The student model will be trained in MaxText, which uses the Orbax checkpointing format. You will download the Hugging Face weights to your mounted disk and convert them for training.

#### Convert checkpoint to MaxText format

The following command downloads the Hugging Face weights and converts them to the MaxText format.

**Note:** This conversion script requires PyTorch.

```bash
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```

```bash
# Set the checkpoint directory
export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_DIRECTORY}/llama3.1-8b-ckpt

# Convert to MaxText format
python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
  model_name=llama3.1-8b \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${PRE_TRAINED_MODEL_CKPT_DIRECTORY} \
  scan_layers=True skip_jax_distributed_system=True
```
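The SFT command below loads the converted parameters from `${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items`. A short check that this path exists before launching training can save a failed run; this is a sketch, assuming only the directory layout implied by the `load_parameters_path` used later:

```python
# Verify the converted Orbax checkpoint layout before starting SFT.
import os
from pathlib import Path

ckpt_dir = Path(os.environ["PRE_TRAINED_MODEL_CKPT_DIRECTORY"])
items = ckpt_dir / "0" / "items"  # path used as load_parameters_path below

if items.is_dir():
    size_gb = sum(f.stat().st_size for f in items.rglob("*") if f.is_file()) / 1024**3
    print(f"Found converted checkpoint at {items} (~{size_gb:.1f} GB)")
else:
    raise FileNotFoundError(f"Expected converted checkpoint at {items}; re-run the conversion step.")
```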
### Generate dataset using vLLM (Teacher Step)

Use the provided script `generate_distillation_data_vllm.py` to generate the dataset from the teacher model. This script writes a Parquet dataset compatible with MaxText SFT.

Run the generation script:

```bash
export OUTPUT_DATASET=${BASE_DIRECTORY}/datasets/distillation_data.parquet

python3 -m tools.data_generation.generate_distillation_data_vllm \
  --dataset-path HuggingFaceH4/ultrachat_200k \
  --data-split train_sft \
  --data-columns messages \
  --hf-access-token $HF_TOKEN \
  --teacher-model ${BASE_DIRECTORY}/qwen3-32b \
  --use-chat-template \
  --num-prompts 5120 \
  --num-generations 2 \
  --output-file ${OUTPUT_DATASET}
```
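With `--num-prompts 5120` and `--num-generations 2`, the run produces roughly 5120 × 2 = 10,240 prompt-completion records. Before fine-tuning, it is worth spot-checking that the Parquet file contains the conversational `messages` column the SFT trainer expects. A small sketch using pandas; the exact column layout written by the script is an assumption here:

```python
# Spot-check the teacher-generated Parquet dataset (illustrative; column layout assumed).
import os
import pandas as pd

dataset_path = os.environ["OUTPUT_DATASET"]
df = pd.read_parquet(dataset_path)

print(f"{len(df)} records, columns: {list(df.columns)}")
# Expect a conversational format: each row's 'messages' holds a list of
# {"role": ..., "content": ...} turns ending with the teacher's completion.
example = df.iloc[0]["messages"]
for turn in example:
    print(turn["role"], "->", str(turn["content"])[:80])
```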
### Fine-tune the student model using Supervised Fine Tuning (SFT)

You can now fine-tune your smaller student model using supervised fine-tuning in MaxText.

#### Fine-tune the student model using the generated dataset

Example command to run fine-tuning on a TPU v6e-8:

```bash
python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b \
  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \
  dataset_type=hf \
  hf_path=parquet \
  hf_train_files=${OUTPUT_DATASET} \
  train_split='train' \
  train_data_columns=['messages'] \
  load_parameters_path=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items \
  model_name=llama3.1-8b \
  per_device_batch_size=2 \
  steps=200 \
  ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
  max_target_length=2048 \
  hf_access_token=$HF_TOKEN \
  profiler=xplane
```
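To see how far these settings get through the generated data: on a v6e-8 (8 chips), `per_device_batch_size=2` gives a global batch of 16 sequences per step, so 200 steps process about 3,200 sequences against the ~10,240 generated records (ignoring any example packing); increase `steps` or the batch size to cover the full dataset. The arithmetic, under the assumption of 8 devices and one example per sequence:

```python
# Back-of-the-envelope sizing for the SFT run above (8 devices assumed for v6e-8).
num_devices = 8
per_device_batch_size = 2
steps = 200
max_target_length = 2048
generated_examples = 5120 * 2  # --num-prompts * --num-generations

global_batch = per_device_batch_size * num_devices   # 16 sequences per step
examples_seen = global_batch * steps                 # 3,200 sequences
tokens_per_step = global_batch * max_target_length   # 32,768 tokens (upper bound)
epochs = examples_seen / generated_examples          # ~0.31 passes over the data

print(global_batch, examples_seen, tokens_per_step, round(epochs, 2))
```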
#### **[OPTIONAL]** Fine-tune the student model using the original dataset

The checkpoint from the student model's fine-tuning (on the teacher-generated dataset) can be used for a subsequent fine-tuning stage. In this step, the student model is fine-tuned on the original dataset that was initially provided to the teacher model for generating the dataset.

```bash
# Get the latest checkpoint for fine-tuned student model
CHECKPOINTS_PATH=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b/${RUN_NAME}/checkpoints
checkpoints=$(ls $CHECKPOINTS_PATH)
integer_dirs=()
for dir in $checkpoints; do
  dir_name=$(basename "$dir")
  # (assumed) keep only numerically named checkpoint step directories
  if [[ "$dir_name" =~ ^[0-9]+$ ]]; then
    integer_dirs+=("$dir_name")
  fi
done
sorted_dirs=($(printf '%s\n' "${integer_dirs[@]}" | sort -n))
largest_dir="${sorted_dirs[-1]}"
FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/model_params

# Fine-tune student model on original dataset
python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
  run_name=${RUN_NAME}_stage2 \
  base_output_directory=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b \
  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \
  dataset_type=hf \
  hf_path='HuggingFaceH4/ultrachat_200k' \
  train_split='train_sft' \
  train_data_columns=['messages'] \
  load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \
  model_name=llama3.1-8b \
  per_device_batch_size=2 \
  steps=200 \
  ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
  max_target_length=2048 \
  hf_access_token=$HF_TOKEN \
  profiler=xplane
```
