Skip to content

Commit f02adc1

Browse files
committed
Add soft distillation training script and configuration.
1 parent 2887b75 commit f02adc1

6 files changed

Lines changed: 879 additions & 2 deletions

File tree

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Soft-distillation training configuration.

# Start from the stock MaxText defaults; everything below overrides base.yml.
base_config: "base.yml"

# --- Student Specifics ---
# Forwarded as kwargs when the Student config is initialized.
student_overrides:
  model_name: "llama3.1-8b"

# --- Teacher Specifics ---
# Forwarded as kwargs when the Teacher config is initialized.
# NOTE(review): teacher and student are the same architecture here
# (self-distillation) — confirm this is intentional.
teacher_overrides:
  model_name: "llama3.1-8b"

# --- Distillation Loss ---
# Relative weight of the soft (teacher) loss term and the softmax temperature
# used to soften logits before computing it.
distill_alpha: 0.5
distill_temperature: 1.0

# --- Dataset & Tokenizer ---
hf_path: "OptimalScale/ClimbMix"
dataset_type: "hf"
tokenizer_path: "meta-llama/Llama-3.1-8B"
tokenizer_type: "huggingface"

max_target_length: 2048

# --- Training Loop ---
steps: 200000
checkpoint_period: 2000
log_period: 10
save_checkpoint_on_completion: True

# --- Batch Size Strategy ---
# Global Batch Size = per_device_batch_size * num_devices * gradient_accumulation_steps
per_device_batch_size: 2
gradient_accumulation_steps: 8

# --- Learning Rate Schedule ---
# Cosine decay over the full run with a fractional warmup.
learning_rate: 2.0e-4
learning_rate_schedule_steps: 200000
warmup_steps_fraction: 0.1
cosine_learning_rate_final_fraction: 0.1

src/MaxText/configs/types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,24 @@ class FineTuning(BaseModel):
967967
use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.")
968968

969969

970+
class Distillation(BaseModel):
  """Configuration for Knowledge Distillation.

  Groups the settings that control a student/teacher distillation run:
  per-model config overrides and the loss-mixing hyperparameters.

  Attributes:
    student_overrides: Keyword overrides applied to the Student model config.
    teacher_overrides: Keyword overrides applied to the Teacher model config.
    distill_alpha: Weight of the distillation loss component.
    distill_temperature: Softmax temperature used to soften logits.
  """

  # --- Overrides ---
  # These dictionaries allow flexible configuration injection for Student/Teacher
  # without needing to duplicate the entire MaxText schema here.
  student_overrides: dict[str, Any] = Field(
      default_factory=dict, description="Overrides specific to the Student model (e.g., {'num_query_heads': 16})."
  )
  teacher_overrides: dict[str, Any] = Field(
      default_factory=dict, description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})."
  )

  # --- Loss Params ---
  # A negative loss weight is nonsensical; reject it at validation time.
  # (Values > 1 are allowed in case the loss is not a convex combination.)
  distill_alpha: float = Field(0.5, ge=0.0, description="Weight for the distillation loss component.")
  # Logits are divided by the temperature, so it must be strictly positive.
  distill_temperature: float = Field(1.0, gt=0.0, description="Temperature for distillation softening.")
970988
class TrainingLoop(BaseModel):
971989
"""Configuration for the main training loop, evaluation, and reproducibility."""
972990

@@ -1634,6 +1652,7 @@ class MaxTextConfig(
16341652
AdamW,
16351653
Muon,
16361654
FineTuning,
1655+
Distillation,
16371656
# Reinforcement Learning
16381657
RLHardware,
16391658
VLLM,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)