Skip to content

Commit cc4b912

Browse files
committed
Removed redundant offline_distillation flag and relied on offline_data_dir to know when to run offfline vs online distillation
1 parent 56e6f07 commit cc4b912

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,10 +1082,7 @@ class Distillation(BaseModel):
10821082
description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
10831083
)
10841084

1085-
# --- Offline Distillation Fields ---
1086-
offline_distillation: bool = Field(
1087-
False, description="If True, enables offline distillation using pre-generated teacher data."
1088-
)
1085+
# --- Offline Distillation Field ---
10891086
offline_data_dir: Optional[str] = Field(
10901087
None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
10911088
)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,14 @@ def main(argv: Sequence[str]) -> None:
629629
student_overrides = global_config.student_overrides
630630
student_config = pyconfig.initialize(argv, **student_overrides)
631631

632+
is_offline = bool(global_config.offline_data_dir)
633+
632634
# 3. Initialize TEACHER Config
633635
# We isolate the Teacher from Student CLI arguments (like pruning params).
634636
teacher_overrides = global_config.teacher_overrides
635637

636638
# Ensure load_parameters_path is set in overrides
637-
if not global_config.offline_distillation and not teacher_overrides.get("load_parameters_path"):
639+
if not is_offline and not teacher_overrides.get("load_parameters_path"):
638640
raise ValueError(
639641
"Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
640642
"in your config or arguments."
@@ -646,7 +648,7 @@ def main(argv: Sequence[str]) -> None:
646648
teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
647649

648650
# 4. Run Training
649-
train_distill(student_config, teacher_config, global_config.offline_distillation, global_config.offline_data_dir)
651+
train_distill(student_config, teacher_config, is_offline, global_config.offline_data_dir)
650652

651653

652654
if __name__ == "__main__":

0 commit comments

Comments
 (0)