
Commit d5520a6

Merge pull request #3393 from AI-Hypercomputer:vladk/sft-completion-fix3
PiperOrigin-RevId: 882794487
2 parents: ef98ae5 + 5d36835

2 files changed: 75 additions & 4 deletions


src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 8 additions & 3 deletions
@@ -141,8 +141,8 @@ def model_forward_fn(
       decoder_positions=positions,
       decoder_segment_ids=decoder_segment_ids,
       enable_dropout=config.enable_dropout,
-      decoder_target_tokens=kwargs.get("targets", None),
-      decoder_target_mask=kwargs.get("targets_segmentation", None),
+      decoder_target_tokens=kwargs.get("decoder_target_tokens", None),
+      decoder_target_mask=kwargs.get("decoder_target_mask", None),
   )
   out_projection_activations = None
   if config.distill_beta > 0.0:
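
This hunk fixes a silent keyword mismatch: model_forward_fn looked up kwargs.get("targets") and kwargs.get("targets_segmentation"), but its callers pass those values under the names decoder_target_tokens and decoder_target_mask, so both lookups always fell back to None. A minimal sketch of that failure mode (the function and keys below are illustrative, not the MaxText API):

def forward(**kwargs):
  # Wrong key: callers pass decoder_target_tokens=..., so this is always None.
  targets = kwargs.get("targets", None)
  # Matching key: picks up the value the caller actually supplied.
  tokens = kwargs.get("decoder_target_tokens", None)
  return targets, tokens

print(forward(decoder_target_tokens=[1, 2, 3]))  # (None, [1, 2, 3])
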
@@ -224,6 +224,8 @@ def loss_wrapper(student, teacher, batch):
         positions=batch["positions"],
         attention_mask=batch.get("attention_mask"),
         decoder_segment_ids=batch.get("decoder_segment_ids"),
+        decoder_target_tokens=batch.get("targets", None),
+        decoder_target_mask=batch.get("targets_segmentation", None),
         cache=None,
     )

@@ -235,9 +237,12 @@ def loss_wrapper(student, teacher, batch):
         positions=batch["positions"],
         attention_mask=batch.get("attention_mask"),
         decoder_segment_ids=batch.get("decoder_segment_ids"),
+        decoder_target_tokens=batch.get("targets", None),
+        decoder_target_mask=batch.get("targets_segmentation", None),
         cache=None,
     )
-    labels = self.strategy.labels_fn(batch["targets"])
+    # we should apply a mask for labels to disable segment-separator tokens
+    labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
     return self.strategy.compute_loss(student_output, teacher_output, labels)

   # Because student is the 0th argument, argnums=0 guarantees
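
With targets_segmentation threaded through loss_wrapper, labels_fn can exclude segment-separator and padding positions from the distillation loss. The diff does not show the labels_fn implementation; the sketch below assumes the common convention that positions with segmentation id 0 are replaced by an ignore index the loss skips:

import jax.numpy as jnp

IGNORE_INDEX = -1  # assumed sentinel; the real labels_fn may use a different scheme

def masked_labels(targets, targets_segmentation=None):
  # Hypothetical stand-in for strategy.labels_fn: drop positions whose
  # segmentation id is 0 (padding and segment separators).
  if targets_segmentation is None:
    return targets
  return jnp.where(targets_segmentation != 0, targets, IGNORE_INDEX)

print(masked_labels(jnp.array([[5, 7, 9]]), jnp.array([[1, 1, 0]])))
# [[ 5  7 -1]]: the separator position no longer contributes to the loss.
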

tests/unit/train_distill_test.py

Lines changed: 67 additions & 1 deletion
@@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a
         positions=mock_batch["positions"],
         attention_mask=mock_batch["attention_mask"],
         decoder_segment_ids=mock_batch["decoder_segment_ids"],
+        decoder_target_tokens=mock_batch.get("targets", None),
+        decoder_target_mask=mock_batch.get("targets_segmentation", None),
         cache=None,
     )

@@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
         positions=mock_batch["positions"],
         attention_mask=mock_batch["attention_mask"],
         decoder_segment_ids=mock_batch["decoder_segment_ids"],
+        decoder_target_tokens=mock_batch.get("targets", None),
         cache=None,
+        decoder_target_mask=None,
     )

     trainer.strategy.student_forward_fn.assert_called_once_with(
@@ -237,18 +241,80 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
         positions=mock_batch["positions"],
         attention_mask=mock_batch["attention_mask"],
         decoder_segment_ids=mock_batch["decoder_segment_ids"],
+        decoder_target_tokens=mock_batch.get("targets", None),
         cache=None,
+        decoder_target_mask=None,
     )

     # Verify loss computation and optimizer update
-    trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"])
+    trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
     trainer.strategy.compute_loss.assert_called_once()
     optimizer.update.assert_called_once_with(student_model, mock_grads)

     # Verify the final returns match what grad_fn produced
     self.assertEqual(loss, mock_loss)
     self.assertEqual(aux, mock_aux)

+  @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map")
+  @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad")
+  def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map):
+    """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask."""
+    # 1. Initialize Trainer
+    # pylint: disable=no-value-for-parameter
+    trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
+    trainer.strategy = mock.Mock()
+
+    # 2. Setup Batch WITH targets_segmentation
+    mock_targets_segmentation = jnp.array([[1, 1, 0]])
+    mock_batch = {
+        "input_tokens": mock.Mock(),
+        "positions": mock.Mock(),
+        "attention_mask": mock.Mock(),
+        "decoder_segment_ids": mock.Mock(),
+        "targets": mock.Mock(),
+        "targets_segmentation": mock_targets_segmentation,
+    }
+    trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch)
+
+    # 3. Setup Models & Inputs
+    teacher_model, student_model = mock.Mock(), mock.Mock()
+    model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model)
+    optimizer, inputs = mock.Mock(), mock.Mock()
+
+    # 4. Configure mocked nnx.value_and_grad
+    mock_grad_fn = mock.Mock(return_value=((mock.Mock(), mock.Mock()), mock.Mock()))
+    mock_value_and_grad.return_value = mock_grad_fn
+
+    # 5. Execute outer function & trigger inner loss_wrapper
+    trainer._train_step(model_bundle, optimizer, inputs)
+    loss_wrapper = mock_value_and_grad.call_args[0][0]
+    loss_wrapper(student_model, teacher_model, mock_batch)
+
+    # 6. Assertions
+    trainer.strategy.labels_fn.assert_called_once_with(
+        mock_batch["targets"], targets_segmentation=mock_targets_segmentation
+    )
+    trainer.strategy.student_forward_fn.assert_called_once_with(
+        model=student_model,
+        input_tokens=mock_batch["input_tokens"],
+        positions=mock_batch["positions"],
+        attention_mask=mock_batch["attention_mask"],
+        decoder_segment_ids=mock_batch["decoder_segment_ids"],
+        decoder_target_tokens=mock_batch["targets"],
+        decoder_target_mask=mock_targets_segmentation,
+        cache=None,
+    )
+    trainer.strategy.teacher_forward_fn.assert_called_once_with(
+        model=teacher_model,
+        input_tokens=mock_batch["input_tokens"],
+        positions=mock_batch["positions"],
+        attention_mask=mock_batch["attention_mask"],
+        decoder_segment_ids=mock_batch["decoder_segment_ids"],
+        decoder_target_tokens=mock_batch["targets"],
+        decoder_target_mask=mock_targets_segmentation,
+        cache=None,
+    )
+
   def test_optimizer_factory(self):
     """Verifies the optimizer factory injects hyperparams and handles configs."""
     # Mock config
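
The new test drives the inner loss_wrapper without real models by patching nnx.value_and_grad and pulling the closure the trainer passed to it out of call_args. Stripped of the MaxText specifics, the capture pattern looks like this (train_step and the other names below are illustrative):

from unittest import mock

def train_step(value_and_grad, batch):
  # The trainer builds an inner closure and hands it to value_and_grad.
  def loss_wrapper(params):
    return sum(params) + batch["offset"]
  return value_and_grad(loss_wrapper)

mock_vag = mock.Mock(return_value=mock.Mock())
train_step(mock_vag, {"offset": 1})

# call_args[0][0] is the first positional argument: the captured closure,
# which the test can now invoke directly with mock inputs.
loss_wrapper = mock_vag.call_args[0][0]
assert loss_wrapper([1, 2]) == 4
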
