Skip to content

Commit b6bea94

Browse files
committed
added a unit test + format
1 parent 0d80d3d commit b6bea94

2 files changed

Lines changed: 63 additions & 4 deletions

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ def create_forward_fn(config: pyconfig.HyperParameters) -> Callable[..., distill
131131
"""
132132

133133
def model_forward_fn(
134-
model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None,
135-
**kwargs
134+
model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs
136135
) -> distillation_utils.DistillationForwardOutput:
137136
"""Forward pass wrapper adapted for raw MaxText models."""
138137
del attention_mask # Unused
@@ -215,7 +214,7 @@ def _train_step(self, model, optimizer, inputs):
215214

216215
batch = self.gen_model_input_fn(inputs)
217216

218-
def loss_wrapper(student, teacher, batch):
217+
def loss_wrapper(student, teacher, batch):
219218
if "teacher_output" in batch:
220219
teacher_output = batch["teacher_output"]
221220
else:
@@ -440,7 +439,7 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
440439

441440
# 3. Define Distillation Strategy
442441
def labels_fn(targets, targets_segmentation=None, **kwargs):
443-
"""Converts integer targets to masked one-hot vectors for hard label loss."""
442+
"""Converts integer targets to masked one-hot vectors for hard label loss."""
444443
del kwargs # Unused
445444
one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
446445
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]

tests/unit/train_distill_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,66 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
249249
self.assertEqual(loss, mock_loss)
250250
self.assertEqual(aux, mock_aux)
251251

252+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map")
253+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad")
254+
def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map):
255+
"""Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask."""
256+
# 1. Initialize Trainer
257+
# pylint: disable=no-value-for-parameter
258+
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
259+
trainer.strategy = mock.Mock()
260+
261+
# 2. Setup Batch WITH targets_segmentation
262+
mock_targets_segmentation = jnp.array([[1, 1, 0]])
263+
mock_batch = {
264+
"input_tokens": mock.Mock(),
265+
"positions": mock.Mock(),
266+
"attention_mask": mock.Mock(),
267+
"decoder_segment_ids": mock.Mock(),
268+
"targets": mock.Mock(),
269+
"targets_segmentation": mock_targets_segmentation,
270+
}
271+
trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch)
272+
273+
# 3. Setup Models & Inputs
274+
teacher_model, student_model = mock.Mock(), mock.Mock()
275+
model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model)
276+
optimizer, inputs = mock.Mock(), mock.Mock()
277+
278+
# 4. Configure mocked nnx.value_and_grad
279+
mock_grad_fn = mock.Mock(return_value=((mock.Mock(), mock.Mock()), mock.Mock()))
280+
mock_value_and_grad.return_value = mock_grad_fn
281+
282+
# 5. Execute outer function & trigger inner loss_wrapper
283+
trainer._train_step(model_bundle, optimizer, inputs)
284+
loss_wrapper = mock_value_and_grad.call_args[0][0]
285+
loss_wrapper(student_model, teacher_model, mock_batch)
286+
287+
# 6. Assertions
288+
trainer.strategy.labels_fn.assert_called_once_with(
289+
mock_batch["targets"], targets_segmentation=mock_targets_segmentation
290+
)
291+
trainer.strategy.student_forward_fn.assert_called_once_with(
292+
model=student_model,
293+
input_tokens=mock_batch["input_tokens"],
294+
positions=mock_batch["positions"],
295+
attention_mask=mock_batch["attention_mask"],
296+
decoder_segment_ids=mock_batch["decoder_segment_ids"],
297+
decoder_target_tokens=mock_batch["targets"],
298+
decoder_target_mask=mock_targets_segmentation,
299+
cache=None,
300+
)
301+
trainer.strategy.teacher_forward_fn.assert_called_once_with(
302+
model=teacher_model,
303+
input_tokens=mock_batch["input_tokens"],
304+
positions=mock_batch["positions"],
305+
attention_mask=mock_batch["attention_mask"],
306+
decoder_segment_ids=mock_batch["decoder_segment_ids"],
307+
decoder_target_tokens=mock_batch["targets"],
308+
decoder_target_mask=mock_targets_segmentation,
309+
cache=None,
310+
)
311+
252312
def test_optimizer_factory(self):
253313
"""Verifies the optimizer factory injects hyperparams and handles configs."""
254314
# Mock config

0 commit comments

Comments (0)