Commit d486003

moved cmd args into the distillation config to make the command easier to use

1 parent: f61a2e5

2 files changed: 14 additions & 22 deletions

src/maxtext/configs/types.py (10 additions & 0 deletions)
@@ -1081,6 +1081,16 @@ class Distillation(BaseModel):
       default_factory=dict,
       description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64}).",
   )
+
+  # --- Offline Distillation Fields ---
+  offline_distillation: bool = Field(
+      False,
+      description="If True, enables offline distillation using pre-generated teacher data."
+  )
+  offline_data_dir: Optional[str] = Field(
+      None,
+      description="GCS or local path to the pre-generated ArrayRecord teacher data."
+  )
 
   # --- Loss Params ---
   distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
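For context, a minimal, self-contained sketch of the Pydantic pattern the new fields follow: both default to "off", so existing online-distillation configs keep working unchanged. The model_validator tying the two fields together is a hypothetical illustration, not part of the commit:

from typing import Optional

from pydantic import BaseModel, Field, model_validator


class Distillation(BaseModel):
  offline_distillation: bool = Field(
      False,
      description="If True, enables offline distillation using pre-generated teacher data.",
  )
  offline_data_dir: Optional[str] = Field(
      None,
      description="GCS or local path to the pre-generated ArrayRecord teacher data.",
  )

  @model_validator(mode="after")
  def _check_offline_pairing(self):
    # Offline mode has nothing to read without a data directory.
    if self.offline_distillation and not self.offline_data_dir:
      raise ValueError("offline_data_dir must be set when offline_distillation=True")
    return self


cfg = Distillation(offline_distillation=True, offline_data_dir="gs://my-bucket/teacher-data")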

src/maxtext/trainers/post_train/distillation/train_distill.py (4 additions & 22 deletions)
@@ -619,7 +619,7 @@ def custom_gen_model_input_fn(batch):
   max_logging.log("Distillation Complete.")
 
 
-def main(argv: Sequence[str], local_args) -> None:
+def main(argv: Sequence[str]) -> None:
   """Entry point for the script.
 
   Parses configuration, isolates Student and Teacher overrides, and triggers the
@@ -638,7 +638,7 @@ def main(argv: Sequence[str], local_args) -> None:
   teacher_overrides = global_config.teacher_overrides
 
   # Ensure load_parameters_path is set in overrides
-  if not local_args.offline_distillation and not teacher_overrides.get("load_parameters_path"):
+  if not global_config.offline_distillation and not teacher_overrides.get("load_parameters_path"):
     raise ValueError(
         "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
         "in your config or arguments."
@@ -650,26 +650,8 @@ def main(argv: Sequence[str], local_args) -> None:
   teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides)
 
   # 4. Run Training
-  train_distill(student_config, teacher_config, local_args.offline_distillation, local_args.offline_data_dir)
+  train_distill(student_config, teacher_config, global_config.offline_distillation, global_config.offline_data_dir)
 
 
 if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      "--offline_distillation",
-      action="store_true",
-      help="Pass this flag to enable offline distillation.",
-  )
-  parser.add_argument(
-      "--offline_data_dir",
-      type=str,
-      required=False,
-      default=None,
-      help="GCS or local path to the pre-generated ArrayRecord teacher data.",
-  )
-
-  # parse_known_args separates our custom flags from MaxText's standard args
-  local_arg, remaining_args = parser.parse_known_args()
-
-  main_wrapper = functools.partial(main, local_args=local_arg)
-  app.run(main_wrapper, argv=[sys.argv[0]] + remaining_args)
+  app.run(main)
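The deleted boilerplate existed because absl's app.run invokes its callback with a single argv argument, so the extra local_args parameter had to be bound via functools.partial after parse_known_args had split the custom flags away from MaxText's own arguments. A minimal sketch of that old wiring (names mirror the deleted code; the absl contract described is the standard one, everything else here is illustrative):

import argparse
import functools
import sys
from typing import Sequence

from absl import app


def main(argv: Sequence[str], local_args) -> None:
  # absl passes the post-flag-parsing argv; local_args arrives through the partial.
  print("argv:", argv, "offline:", local_args.offline_distillation)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--offline_distillation", action="store_true")
  # parse_known_args keeps unrecognized args for absl/pyconfig to consume.
  local_args, remaining_args = parser.parse_known_args()
  app.run(functools.partial(main, local_args=local_args), argv=[sys.argv[0]] + remaining_args)

With both settings living on the config object, all of this collapses to app.run(main), and they ride along as ordinary config overrides on the command line (the exact key syntax, e.g. offline_distillation=true, depends on how pyconfig exposes the Distillation model and is an assumption here).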
