# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Soft Distillation Configuration

# Inherit MaxText defaults
base_config: "base.yml"

# --- Student Specifics ---
# These are passed as kwargs to the Student config initialization
student_overrides:
  model_name: "llama3.1-8b"

# --- Teacher Specifics ---
# These are passed as kwargs to the Teacher config initialization
teacher_overrides:
  model_name: "llama3.1-8b"

# --- Distillation Loss ---
distill_alpha: 0.5
distill_temperature: 1.0
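# How these two knobs are conventionally combined (assumed here for
# illustration; the exact objective lives in the training code): a standard
# Hinton-style soft-distillation loss is
#   loss = alpha * CE(student_logits, labels)
#        + (1 - alpha) * T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))
# where alpha = distill_alpha and T = distill_temperature. alpha = 0.5 weights
# the hard-label and teacher terms equally, and T = 1.0 leaves both
# distributions unsoftened.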

# --- Dataset & Tokenizer ---
hf_path: "OptimalScale/ClimbMix"
dataset_type: "hf"
tokenizer_path: "meta-llama/Llama-3.1-8B"
tokenizer_type: "huggingface"

max_target_length: 2048

# --- Training Loop ---
steps: 200000
checkpoint_period: 2000
log_period: 10
save_checkpoint_on_completion: True

# --- Batch Size Strategy ---
# Global Batch Size = per_device_batch_size * num_devices * gradient_accumulation_steps
per_device_batch_size: 2
gradient_accumulation_steps: 8
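# Worked example (the device count is illustrative, not set in this file):
# on 64 chips, 2 per device * 64 devices * 8 accumulation steps
# = 1,024 sequences per optimizer step, i.e. 1,024 * 2,048 ≈ 2.1M tokens
# per step at max_target_length = 2048.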

# --- Learning Rate Schedule ---
learning_rate: 2.0e-4
learning_rate_schedule_steps: 200000
warmup_steps_fraction: 0.1
cosine_learning_rate_final_fraction: 0.1
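# With these values, and assuming MaxText's usual linear-warmup-then-cosine
# schedule: warmup lasts 0.1 * 200,000 = 20,000 steps up to the 2.0e-4 peak,
# after which cosine decay takes the rate down to 0.1 * 2.0e-4 = 2.0e-5
# by step 200,000.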