@@ -619,7 +619,7 @@ def custom_gen_model_input_fn(batch):
619619 max_logging .log ("Distillation Complete." )
620620
621621
622- def main (argv : Sequence [str ], local_args ) -> None :
622+ def main (argv : Sequence [str ]) -> None :
623623 """Entry point for the script.
624624
625625 Parses configuration, isolates Student and Teacher overrides, and triggers the
@@ -638,7 +638,7 @@ def main(argv: Sequence[str], local_args) -> None:
638638 teacher_overrides = global_config .teacher_overrides
639639
640640 # Ensure load_parameters_path is set in overrides
641- if not local_args .offline_distillation and not teacher_overrides .get ("load_parameters_path" ):
641+ if not global_config .offline_distillation and not teacher_overrides .get ("load_parameters_path" ):
642642 raise ValueError (
643643 "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
644644 "in your config or arguments."
@@ -650,26 +650,8 @@ def main(argv: Sequence[str], local_args) -> None:
650650 teacher_config = pyconfig .initialize (teacher_argv , ** teacher_overrides )
651651
652652 # 4. Run Training
653- train_distill (student_config , teacher_config , local_args .offline_distillation , local_args .offline_data_dir )
653+ train_distill (student_config , teacher_config , global_config .offline_distillation , global_config .offline_data_dir )
654654
655655
656656if __name__ == "__main__" :
657- parser = argparse .ArgumentParser ()
658- parser .add_argument (
659- "--offline_distillation" ,
660- action = "store_true" ,
661- help = "Pass this flag to enable offline distillation." ,
662- )
663- parser .add_argument (
664- "--offline_data_dir" ,
665- type = str ,
666- required = False ,
667- default = None ,
668- help = "GCS or local path to the pre-generated ArrayRecord teacher data." ,
669- )
670-
671- # parse_known_args separates our custom flags from MaxText's standard args
672- local_arg , remaining_args = parser .parse_known_args ()
673-
674- main_wrapper = functools .partial (main , local_args = local_arg )
675- app .run (main_wrapper , argv = [sys .argv [0 ]] + remaining_args )
657+ app .run (main )
0 commit comments