@@ -671,6 +671,162 @@ def __call__(self, x):
671671 # Verify weights HAVE changed
672672 with self .assertRaises (AssertionError , msg = "Weights should have updated on the second pass." ):
673673 np .testing .assert_allclose (student .linear .kernel .value , initial_weights )
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh")
@mock.patch("maxtext.configs.pyconfig.initialize")
def test_main_offline_mode_skips_teacher_loading(
    self,
    mock_pyconfig_init,
    mock_create_mesh,
    mock_build_tokenizer,
    mock_get_model,
    mock_create_iterator,
    mock_trainer_cls,
    mock_offline_iter_cls,
):
    """Verifies offline mode (offline_data_dir is set) skips teacher model loading.

    The patch decorators are applied bottom-up, so the mock arguments arrive
    in the reverse order of the decorator stack. `mock_create_mesh` and
    `mock_offline_iter_cls` are only stubbed out; this test makes no
    assertions on them.

    `pyconfig.initialize` is stubbed with a `side_effect` list, so successive
    calls return the global, student and teacher configs in that order.
    """
    # 1. Configs
    mock_global = mock.Mock(
        student_overrides={},
        teacher_overrides={},  # No checkpoint needed
        offline_data_dir="gs://bucket/data",  # Triggers offline mode
    )

    mock_student_cfg = mock.Mock(
        vocab_size=32000,
        mesh_axes=("data",),
        dataset_type="grain",
        # Dummy numbers for optimizer math
        learning_rate=1e-4,
        warmup_steps_fraction=0.1,
        learning_rate_final_fraction=0.1,
        steps=100,
        checkpoint_period=10,
        gradient_clipping_threshold=0.0,
        eval_interval=-1,
        # Dummy numbers for strategy math/logic
        distill_temperature=1.0,
        distill_alpha=0.5,
        distill_beta=0.0,
        distill_layer_indices=None,
        use_sft=False,
        enable_dropout=False,
        # Dummy variables for Checkpointer and Logger
        max_num_checkpoints_to_keep=1,
        async_checkpointing=False,
        profiler="none",
        tensorboard_dir="",
        checkpoint_dir="",
        log_period=10,
        save_checkpoint_on_completion=False,
        logical_axis_rules=[],
    )

    mock_teacher_cfg = mock.Mock(vocab_size=32000)
    mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]

    # 2. Model Loading
    mock_student_model = mock.Mock()
    mock_get_model.return_value = mock_student_model

    # 3. Tokenizer & Data Iterator
    mock_build_tokenizer.return_value = mock.Mock(pad_id=0)
    mock_create_iterator.return_value = (None, None)

    train_distill.main(["train_distill.py", "config.yml"])

    # 4. Assertions
    # get_maxtext_model must be called exactly once, for the student only.
    mock_get_model.assert_called_once_with(mock_student_cfg, mock.ANY)

    # Fail with a clear message if the trainer was never constructed;
    # otherwise `call_args` is None and the lookup below raises an opaque
    # AttributeError.
    mock_trainer_cls.assert_called()
    model_bundle = mock_trainer_cls.call_args.kwargs["model"]
    # The student model is set, while the teacher stays None because offline
    # mode should skip loading it.
    self.assertIs(model_bundle.student_model, mock_student_model)
    self.assertIsNone(model_bundle.teacher_model)
753+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh")
@mock.patch("maxtext.configs.pyconfig.initialize")
def test_main_online_mode_loads_teacher(
    self,
    mock_pyconfig_init,
    mock_create_mesh,
    mock_build_tokenizer,
    mock_get_model,
    mock_create_iterator,
    mock_trainer_cls,
):
    """Verifies online mode (offline_data_dir is None) loads both student and teacher models.

    The patch decorators are applied bottom-up, so the mock arguments arrive
    in the reverse order of the decorator stack. `mock_create_mesh` is only
    stubbed out; this test makes no assertions on it.

    `pyconfig.initialize` is stubbed with a `side_effect` list, so successive
    calls return the global, student and teacher configs in that order.
    `get_maxtext_model` likewise returns the student mock on its first call
    and the teacher mock on its second.
    """
    mock_global = mock.Mock(
        student_overrides={},
        teacher_overrides={"load_parameters_path": "gs://ckpt"},
        offline_data_dir=None,  # Triggers online mode
    )

    mock_student_cfg = mock.Mock(
        vocab_size=32000,
        mesh_axes=("data",),
        dataset_type="grain",
        # Dummy numbers for optimizer math
        learning_rate=1e-4,
        warmup_steps_fraction=0.1,
        learning_rate_final_fraction=0.1,
        steps=100,
        checkpoint_period=10,
        gradient_clipping_threshold=0.0,
        eval_interval=-1,
        # Dummy numbers for strategy math/logic
        distill_temperature=1.0,
        distill_alpha=0.5,
        distill_beta=0.0,
        distill_layer_indices=None,
        use_sft=False,
        enable_dropout=False,
        # Dummy variables for Checkpointer and Logger
        max_num_checkpoints_to_keep=1,
        async_checkpointing=False,
        profiler="none",
        tensorboard_dir="",
        checkpoint_dir="",
        log_period=10,
        save_checkpoint_on_completion=False,
        logical_axis_rules=[],
    )

    mock_teacher_cfg = mock.Mock(vocab_size=32000)
    mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]

    mock_student_model = mock.Mock()
    mock_teacher_model = mock.Mock()
    mock_get_model.side_effect = [mock_student_model, mock_teacher_model]

    mock_build_tokenizer.return_value = mock.Mock(pad_id=0)
    mock_create_iterator.return_value = (mock.Mock(), mock.Mock())

    train_distill.main(["train_distill.py", "config.yml"])

    # get_maxtext_model must be called for both the student and the teacher,
    # since online mode should load both.
    self.assertEqual(mock_get_model.call_count, 2)
    mock_get_model.assert_any_call(mock_student_cfg, mock.ANY)
    mock_get_model.assert_any_call(mock_teacher_cfg, mock.ANY)

    # Fail with a clear message if the trainer was never constructed;
    # otherwise `call_args` is None and the lookup below raises an opaque
    # AttributeError.
    mock_trainer_cls.assert_called()
    model_bundle = mock_trainer_cls.call_args.kwargs["model"]
    # Both student and teacher models are set, since online mode should load
    # both.
    self.assertIs(model_bundle.student_model, mock_student_model)
    self.assertIs(model_bundle.teacher_model, mock_teacher_model)
674830
675831
676832if __name__ == "__main__" :
0 commit comments