@@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a
183183 positions = mock_batch ["positions" ],
184184 attention_mask = mock_batch ["attention_mask" ],
185185 decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
186+ decoder_target_tokens = mock_batch .get ("targets" , None ),
187+ decoder_target_mask = mock_batch .get ("targets_segmentation" , None ),
186188 cache = None ,
187189 )
188190
@@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
228230 positions = mock_batch ["positions" ],
229231 attention_mask = mock_batch ["attention_mask" ],
230232 decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
233+ decoder_target_tokens = mock_batch .get ("targets" , None ),
231234 cache = None ,
235+ decoder_target_mask = None ,
232236 )
233237
234238 trainer .strategy .student_forward_fn .assert_called_once_with (
@@ -237,11 +241,13 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
237241 positions = mock_batch ["positions" ],
238242 attention_mask = mock_batch ["attention_mask" ],
239243 decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
244+ decoder_target_tokens = mock_batch .get ("targets" , None ),
240245 cache = None ,
246+ decoder_target_mask = None ,
241247 )
242248
243249 # Verify loss computation and optimizer update
244- trainer .strategy .labels_fn .assert_called_once_with (mock_batch ["targets" ])
250+ trainer .strategy .labels_fn .assert_called_once_with (mock_batch ["targets" ], targets_segmentation = None )
245251 trainer .strategy .compute_loss .assert_called_once ()
246252 optimizer .update .assert_called_once_with (student_model , mock_grads )
247253
0 commit comments