Skip to content

Commit 724b115

Browse files
Merge pull request #2901 from AI-Hypercomputer:agagik-distill
PiperOrigin-RevId: 854336099
2 parents 4423abf + f02adc1 commit 724b115

6 files changed

Lines changed: 879 additions & 2 deletions

File tree

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Soft Distillation Configuration
16+
17+
# Inherit MaxText defaults
18+
base_config: "base.yml"
19+
20+
# --- Student Specifics ---
21+
# These are passed as kwargs to the Student config initialization
22+
student_overrides:
23+
model_name: "llama3.1-8b"
24+
25+
# --- Teacher Specifics ---
26+
# These are passed as kwargs to the Teacher config initialization
27+
teacher_overrides:
28+
model_name: "llama3.1-8b"
29+
30+
# --- Distillation Loss ---
31+
distill_alpha: 0.5
32+
distill_temperature: 1.0
33+
34+
# --- Dataset & Tokenizer ---
35+
hf_path: "OptimalScale/ClimbMix"
36+
dataset_type: "hf"
37+
tokenizer_path: "meta-llama/Llama-3.1-8B"
38+
tokenizer_type: "huggingface"
39+
40+
max_target_length: 2048
41+
42+
# --- Training Loop ---
43+
steps: 200000
44+
checkpoint_period: 2000
45+
log_period: 10
46+
save_checkpoint_on_completion: True
47+
48+
# --- Batch Size Strategy ---
49+
# Global Batch Size = per_device_batch_size * num_devices * gradient_accumulation_steps
50+
per_device_batch_size: 2
51+
gradient_accumulation_steps: 8
52+
53+
# --- Learning Rate Schedule ---
54+
learning_rate: 2.0e-4
55+
learning_rate_schedule_steps: 200000
56+
warmup_steps_fraction: 0.1
57+
cosine_learning_rate_final_fraction: 0.1

src/MaxText/configs/types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,24 @@ class FineTuning(BaseModel):
973973
use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.")
974974

975975

976+
class Distillation(BaseModel):
  """Configuration for Knowledge Distillation.

  Groups the knobs for soft-label distillation: free-form override
  dictionaries used to build the Student and Teacher model configs, plus
  the two loss-mixing hyperparameters.
  """

  # --- Overrides ---
  # These dictionaries allow flexible configuration injection for Student/Teacher
  # without needing to duplicate the entire MaxText schema here.
  # Keys are expected to be MaxText config field names (e.g. 'num_query_heads');
  # values are passed through untyped — no validation happens at this level.
  student_overrides: dict[str, Any] = Field(
      default_factory=dict, description="Overrides specific to the Student model (e.g., {'num_query_heads': 16})."
  )
  teacher_overrides: dict[str, Any] = Field(
      default_factory=dict, description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})."
  )

  # --- Loss Params ---
  # distill_alpha weights the distillation loss component; presumably blended
  # with the hard-label loss as alpha*soft + (1-alpha)*hard — confirm against
  # the training loop. distill_temperature softens the logits before
  # comparison; 1.0 leaves them unchanged. Neither field is range-checked
  # here (alpha outside [0, 1] or temperature <= 0 would be accepted).
  distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
  distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
993+
976994
class TrainingLoop(BaseModel):
977995
"""Configuration for the main training loop, evaluation, and reproducibility."""
978996

@@ -1636,6 +1654,7 @@ class MaxTextConfig(
16361654
AdamW,
16371655
Muon,
16381656
FineTuning,
1657+
Distillation,
16391658
# Reinforcement Learning
16401659
RLHardware,
16411660
VLLM,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)