Skip to content

Commit 5d36835

Browse files
committed
fixed related test
1 parent b6bea94 commit 5d36835

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

tests/unit/train_distill_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a
183183
positions=mock_batch["positions"],
184184
attention_mask=mock_batch["attention_mask"],
185185
decoder_segment_ids=mock_batch["decoder_segment_ids"],
186+
decoder_target_tokens=mock_batch.get("targets", None),
187+
decoder_target_mask=mock_batch.get("targets_segmentation", None),
186188
cache=None,
187189
)
188190

@@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
228230
positions=mock_batch["positions"],
229231
attention_mask=mock_batch["attention_mask"],
230232
decoder_segment_ids=mock_batch["decoder_segment_ids"],
233+
decoder_target_tokens=mock_batch.get("targets", None),
231234
cache=None,
235+
decoder_target_mask=None,
232236
)
233237

234238
trainer.strategy.student_forward_fn.assert_called_once_with(
@@ -237,11 +241,13 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
237241
positions=mock_batch["positions"],
238242
attention_mask=mock_batch["attention_mask"],
239243
decoder_segment_ids=mock_batch["decoder_segment_ids"],
244+
decoder_target_tokens=mock_batch.get("targets", None),
240245
cache=None,
246+
decoder_target_mask=None,
241247
)
242248

243249
# Verify loss computation and optimizer update
244-
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"])
250+
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
245251
trainer.strategy.compute_loss.assert_called_once()
246252
optimizer.update.assert_called_once_with(student_model, mock_grads)
247253

0 commit comments

Comments
 (0)