Knowledge Distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the smaller model to achieve performance levels closer to the larger one, but with significantly fewer parameters and computational resources.
This tutorial focuses on **response-based knowledge distillation**, a technique where the student model is trained to replicate the outputs and behaviors of the teacher model. Within response-based knowledge distillation, two primary methods are often employed:
1. **Offline Distillation (Dataset Generation):**
   - The pre-trained teacher model (running in vLLM) generates a new dataset of input-output pairs.
   - The student model is then trained on this teacher-generated dataset using standard fine-tuning techniques in MaxText.
2. **Online Distillation (Logit Matching):**
   - During the training process, both the teacher model (which is typically frozen) and the student model process the same input data simultaneously.
   - The student model is trained by minimizing a loss function that encourages its output logits to match the logits produced by the teacher model for the same inputs.
## Running Offline Distillation with MaxText
The following recipe demonstrates the process of offline distillation using **Qwen/Qwen3-32B** as the teacher model and **Llama-3.1-8B** as the student model. Since this recipe fine-tunes the student model using Supervised Fine-Tuning (SFT), it's crucial to use the conversational variant for both the teacher and student models. Here's a step-by-step tutorial:
### Prerequisites
#### a. Set up environment variables
```bash
export HF_TOKEN=<Hugging Face access token>
export BASE_DIRECTORY=<Directory to store distillation results>
export HF_REPO_NAME=<Hugging Face repository name to store teacher-generated dataset>
export USERNAME_OR_ORG=<Owner of Hugging Face repository>
```

#### c. Set up storage with Hyperdisk
To store large models and datasets, attach a Hyperdisk to your TPU VM. Refer to the [Google Cloud Hyperdisk documentation](https://cloud.google.com/compute/docs/disks/add-hyperdisk) and [TPU VM management](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm) for detailed instructions.
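
As a rough sketch (the disk name, size, zone, and TPU VM name below are placeholders, and the exact flags should be checked against the linked documentation), creating and attaching a Hyperdisk might look like:

```bash
# Create a Hyperdisk Balanced volume (name, size, and zone are illustrative)
gcloud compute disks create distillation-disk \
  --size=500GB \
  --type=hyperdisk-balanced \
  --zone=us-central2-b

# Attach the disk to the TPU VM (flags may vary by gcloud version)
gcloud alpha compute tpus tpu-vm attach-disk my-tpu-vm \
  --disk=distillation-disk \
  --zone=us-central2-b \
  --mode=read-write
```
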
Inside the TPU VM, format and mount the disk (if not already mounted):
```bash
# Assuming the disk is /dev/sdb, check with lsblk
sudo mkfs.ext4 /dev/sdb
sudo mkdir -p /mnt/hyperdisk
sudo mount /dev/sdb /mnt/hyperdisk
```
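
Note that `mkfs.ext4` erases any existing data on the disk, so skip that step if the disk was formatted previously. You can confirm the mount with `df -h /mnt/hyperdisk`.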
Update `BASE_DIRECTORY` to point to the mounted disk and create the directory:
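
For example (the subdirectory name under the mount point is illustrative):

```bash
# Store all distillation artifacts on the mounted Hyperdisk
export BASE_DIRECTORY=/mnt/hyperdisk/distillation
mkdir -p ${BASE_DIRECTORY}
```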
> **Note:** This tutorial uses a mounted Hyperdisk for performance and reproducibility, because writing large model files and many small I/O operations directly to `gs://` can be significantly slower.
### Obtain and prepare the teacher model
For the teacher model, we will use **vLLM** to run inference. vLLM can load Hugging Face checkpoints directly, so **no conversion to MaxText format is needed** for the teacher. Ensure the teacher model is supported on TPU vLLM (refer to the [vLLM TPU recommended models](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/#text-only-models) for the latest list).
You can simply download the model from Hugging Face to your local directory:

```bash
huggingface-cli download Qwen/Qwen3-32B --repo-type model --local-dir ${BASE_DIRECTORY}/qwen3-32b
```
### Obtain and prepare the student model
The student model will be trained in MaxText, which uses the Orbax checkpointing format. You will download the Hugging Face weights to your mounted disk and convert them for training.
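
For example, assuming the conversational (Instruct) variant of the student checkpoint and an illustrative target directory:

```bash
# Download the student model's Hugging Face weights to the mounted disk
huggingface-cli download meta-llama/Llama-3.1-8B-Instruct --repo-type model --local-dir ${BASE_DIRECTORY}/llama3.1-8b-instruct
```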
#### Convert checkpoint to MaxText format
The following command downloads the Hugging Face weights and converts them to the MaxText format.
**Note:** This conversion script requires PyTorch.
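
If PyTorch is not already installed in your environment, a CPU-only build is assumed to be sufficient for the conversion step:

```bash
# Install a CPU-only PyTorch build (assumed sufficient for checkpoint conversion)
pip install torch --index-url https://download.pytorch.org/whl/cpu
```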
### Generate dataset using vLLM (Teacher Step)

Use the provided script `generate_distillation_data_vllm.py` to generate the dataset from the teacher model. This script writes a Parquet dataset compatible with MaxText SFT.

The `--num-generations` parameter specifies the number of distinct completions to generate per prompt, `--batch-size` controls how many prompts are processed per batch, and `--num-batches` defines how many such batches to run. The total number of prompt-completion pairs generated is approximately `num_batches * batch_size * num_generations`. For example, with `--batch-size 1024`, `--num-generations 2`, and `--num-batches 200`, this yields `200 * 1024 * 2 = 409,600` prompt-completion pairs.

Note that some prompts may be filtered out during pre-processing: prompts whose sequences are longer than `max-prefill-length` are dropped before inference.
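
An invocation might look like the sketch below. Only `--batch-size`, `--num-generations`, and `--num-batches` come from the description above; the remaining flag names and paths are assumptions and should be checked against the script's `--help` output.

```bash
# Hypothetical sketch: --model and --output-path are assumed flag names, not confirmed by this guide.
python3 generate_distillation_data_vllm.py \
  --model ${BASE_DIRECTORY}/qwen3-32b \
  --batch-size 1024 \
  --num-generations 2 \
  --num-batches 200 \
  --output-path ${BASE_DIRECTORY}/distillation-data
```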
### Fine-tune the student model using Supervised Fine Tuning (SFT)
You can now fine-tune your smaller student model using supervised fine-tuning in MaxText.
#### Fine-tune the student model using the generated dataset
Example command to run fine-tuning on a TPU v6e-8:
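
The sketch below shows roughly what such a command could look like; the entry point, config file, and config keys should be verified against the MaxText SFT documentation for your version, and all paths and values are illustrative assumptions. Dataset settings pointing at the teacher-generated Parquet files must be supplied as well.

```bash
# Sketch only: verify the entry point and config keys against your MaxText version.
python3 -m MaxText.sft_trainer MaxText/configs/sft.yml \
  run_name=llama3.1-8b-distillation-sft \
  base_output_directory=${BASE_DIRECTORY}/sft-output \
  model_name=llama3.1-8b \
  load_parameters_path=${BASE_DIRECTORY}/llama3.1-8b-maxtext/0/items \
  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
  per_device_batch_size=1 \
  steps=1000
```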
#### **[OPTIONAL]** Fine-tune the student model using the original dataset
The checkpoint from the student model's fine-tuning (on the teacher-generated dataset) can be used for a subsequent fine-tuning stage. In this step, the student model is fine-tuned on the original dataset that was initially provided to the teacher model for generating the dataset.
```bash
# Get the latest checkpoint for fine-tuned student model