We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent beb19b9 commit f124f10Copy full SHA for f124f10
1 file changed
tests/post_training/unit/train_distill_test.py
@@ -671,6 +671,7 @@ def __call__(self, x):
671
# Verify weights HAVE changed
672
with self.assertRaises(AssertionError, msg="Weights should have updated on the second pass."):
673
np.testing.assert_allclose(student.linear.kernel.value, initial_weights)
674
+
675
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator")
676
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
677
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
0 commit comments