Skip to content

Commit beb19b9

Browse files
committed
Added unit tests to verify that offline and online distillation load the correct models.
1 parent cc4b912 commit beb19b9

1 file changed

Lines changed: 156 additions & 0 deletions

File tree

tests/post_training/unit/train_distill_test.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,162 @@ def __call__(self, x):
671671
# Verify weights HAVE changed
672672
with self.assertRaises(AssertionError, msg="Weights should have updated on the second pass."):
673673
np.testing.assert_allclose(student.linear.kernel.value, initial_weights)
674+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh")
@mock.patch("maxtext.configs.pyconfig.initialize")
def test_main_offline_mode_skips_teacher_loading(
    self,
    mock_pyconfig_init,
    mock_create_mesh,
    mock_build_tokenizer,
    mock_get_model,
    mock_create_iterator,
    mock_trainer_cls,
    mock_offline_iter_cls,
):
  """Verifies offline mode (offline_data_dir is set) skips teacher model loading."""
  # 1. Configs. A populated offline_data_dir is what flips main() into offline mode.
  mock_global = mock.Mock(
      student_overrides={},
      teacher_overrides={},  # No checkpoint needed
      offline_data_dir="gs://bucket/data",  # Triggers offline mode
  )

  # Student config built in one shot: every attribute main() reads, preset as kwargs.
  mock_student_cfg = mock.Mock(
      vocab_size=32000,
      mesh_axes=("data",),
      dataset_type="grain",
      # Dummy numbers for optimizer math
      learning_rate=1e-4,
      warmup_steps_fraction=0.1,
      learning_rate_final_fraction=0.1,
      steps=100,
      checkpoint_period=10,
      gradient_clipping_threshold=0.0,
      eval_interval=-1,
      # Dummy numbers for strategy math/logic
      distill_temperature=1.0,
      distill_alpha=0.5,
      distill_beta=0.0,
      distill_layer_indices=None,
      use_sft=False,
      enable_dropout=False,
      # Dummy variables for Checkpointer and Logger
      max_num_checkpoints_to_keep=1,
      async_checkpointing=False,
      profiler="none",
      tensorboard_dir="",
      checkpoint_dir="",
      log_period=10,
      save_checkpoint_on_completion=False,
      logical_axis_rules=[],
  )

  mock_teacher_cfg = mock.Mock(vocab_size=32000)
  # pyconfig.initialize is consumed three times: global, student, teacher.
  mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]

  # 2. Model loading — only one model should ever be built.
  mock_student_model = mock.Mock()
  mock_get_model.return_value = mock_student_model

  # 3. Tokenizer & data iterator stubs.
  mock_build_tokenizer.return_value = mock.Mock(pad_id=0)
  mock_create_iterator.return_value = (None, None)

  train_distill.main(["train_distill.py", "config.yml"])

  # 4. Assertions.
  # get_maxtext_model must run exactly once — for the student, never the teacher.
  mock_get_model.assert_called_once_with(mock_student_cfg, mock.ANY)

  model_bundle = mock_trainer_cls.call_args.kwargs["model"]
  # Student is wired through; teacher stays None because offline mode skips loading it.
  self.assertIs(model_bundle.student_model, mock_student_model)
  self.assertIsNone(model_bundle.teacher_model)
753+
754+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh")
@mock.patch("maxtext.configs.pyconfig.initialize")
def test_main_online_mode_loads_teacher(
    self,
    mock_pyconfig_init,
    mock_create_mesh,
    mock_build_tokenizer,
    mock_get_model,
    mock_create_iterator,
    mock_trainer_cls,
):
  """Verifies online mode (offline_data_dir is None) loads both student and teacher models."""
  # Global config: a None offline_data_dir is what selects online distillation.
  mock_global = mock.Mock(
      student_overrides={},
      teacher_overrides={"load_parameters_path": "gs://ckpt"},
      offline_data_dir=None,  # Triggers online mode
  )

  # Student config built in one shot: every attribute main() reads, preset as kwargs.
  mock_student_cfg = mock.Mock(
      vocab_size=32000,
      mesh_axes=("data",),
      dataset_type="grain",
      # Dummy numbers for optimizer math
      learning_rate=1e-4,
      warmup_steps_fraction=0.1,
      learning_rate_final_fraction=0.1,
      steps=100,
      checkpoint_period=10,
      gradient_clipping_threshold=0.0,
      eval_interval=-1,
      # Dummy numbers for strategy math/logic
      distill_temperature=1.0,
      distill_alpha=0.5,
      distill_beta=0.0,
      distill_layer_indices=None,
      use_sft=False,
      enable_dropout=False,
      # Dummy variables for Checkpointer and Logger
      max_num_checkpoints_to_keep=1,
      async_checkpointing=False,
      profiler="none",
      tensorboard_dir="",
      checkpoint_dir="",
      log_period=10,
      save_checkpoint_on_completion=False,
      logical_axis_rules=[],
  )

  mock_teacher_cfg = mock.Mock(vocab_size=32000)
  # pyconfig.initialize is consumed three times: global, student, teacher.
  mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]

  # Two distinct model mocks: first call yields the student, second the teacher.
  mock_student_model = mock.Mock()
  mock_teacher_model = mock.Mock()
  mock_get_model.side_effect = [mock_student_model, mock_teacher_model]

  mock_build_tokenizer.return_value = mock.Mock(pad_id=0)
  mock_create_iterator.return_value = (mock.Mock(), mock.Mock())

  train_distill.main(["train_distill.py", "config.yml"])

  # get_maxtext_model must run for BOTH configs in online mode.
  self.assertEqual(mock_get_model.call_count, 2)
  mock_get_model.assert_any_call(mock_student_cfg, mock.ANY)
  mock_get_model.assert_any_call(mock_teacher_cfg, mock.ANY)

  model_bundle = mock_trainer_cls.call_args.kwargs["model"]
  # Both models must land in the bundle handed to the trainer.
  self.assertIs(model_bundle.student_model, mock_student_model)
  self.assertIs(model_bundle.teacher_model, mock_teacher_model)
674830

675831

676832
if __name__ == "__main__":

0 commit comments

Comments
 (0)