@@ -183,6 +183,8 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a
183183 positions = mock_batch ["positions" ],
184184 attention_mask = mock_batch ["attention_mask" ],
185185 decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
186+ decoder_target_tokens = mock_batch .get ("targets" , None ),
187+ decoder_target_mask = mock_batch .get ("targets_segmentation" , None ),
186188 cache = None ,
187189 )
188190
@@ -228,7 +230,9 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
228230 positions = mock_batch ["positions" ],
229231 attention_mask = mock_batch ["attention_mask" ],
230232 decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
233+ decoder_target_tokens = mock_batch .get ("targets" , None ),
231234 cache = None ,
235+ decoder_target_mask = None ,
232236 )
233237
234238 trainer .strategy .student_forward_fn .assert_called_once_with (
@@ -237,18 +241,80 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
237241 positions = mock_batch ["positions" ],
238242 attention_mask = mock_batch ["attention_mask" ],
239243 decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
244+ decoder_target_tokens = mock_batch .get ("targets" , None ),
240245 cache = None ,
246+ decoder_target_mask = None ,
241247 )
242248
243249 # Verify loss computation and optimizer update
244- trainer .strategy .labels_fn .assert_called_once_with (mock_batch ["targets" ])
250+ trainer .strategy .labels_fn .assert_called_once_with (mock_batch ["targets" ], targets_segmentation = None )
245251 trainer .strategy .compute_loss .assert_called_once ()
246252 optimizer .update .assert_called_once_with (student_model , mock_grads )
247253
248254 # Verify the final returns match what grad_fn produced
249255 self .assertEqual (loss , mock_loss )
250256 self .assertEqual (aux , mock_aux )
251257
258+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map" )
259+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad" )
260+ def test_train_step_passes_targets_segmentation (self , mock_value_and_grad , mock_tree_map ):
261+ """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask."""
262+ # 1. Initialize Trainer
263+ # pylint: disable=no-value-for-parameter
264+ trainer = train_distill .MaxTextDistillationTrainer .__new__ (train_distill .MaxTextDistillationTrainer )
265+ trainer .strategy = mock .Mock ()
266+
267+ # 2. Setup Batch WITH targets_segmentation
268+ mock_targets_segmentation = jnp .array ([[1 , 1 , 0 ]])
269+ mock_batch = {
270+ "input_tokens" : mock .Mock (),
271+ "positions" : mock .Mock (),
272+ "attention_mask" : mock .Mock (),
273+ "decoder_segment_ids" : mock .Mock (),
274+ "targets" : mock .Mock (),
275+ "targets_segmentation" : mock_targets_segmentation ,
276+ }
277+ trainer .gen_model_input_fn = mock .Mock (return_value = mock_batch )
278+
279+ # 3. Setup Models & Inputs
280+ teacher_model , student_model = mock .Mock (), mock .Mock ()
281+ model_bundle = train_distill .ModelBundle (teacher_model = teacher_model , student_model = student_model )
282+ optimizer , inputs = mock .Mock (), mock .Mock ()
283+
284+ # 4. Configure mocked nnx.value_and_grad
285+ mock_grad_fn = mock .Mock (return_value = ((mock .Mock (), mock .Mock ()), mock .Mock ()))
286+ mock_value_and_grad .return_value = mock_grad_fn
287+
288+ # 5. Execute outer function & trigger inner loss_wrapper
289+ trainer ._train_step (model_bundle , optimizer , inputs )
290+ loss_wrapper = mock_value_and_grad .call_args [0 ][0 ]
291+ loss_wrapper (student_model , teacher_model , mock_batch )
292+
293+ # 6. Assertions
294+ trainer .strategy .labels_fn .assert_called_once_with (
295+ mock_batch ["targets" ], targets_segmentation = mock_targets_segmentation
296+ )
297+ trainer .strategy .student_forward_fn .assert_called_once_with (
298+ model = student_model ,
299+ input_tokens = mock_batch ["input_tokens" ],
300+ positions = mock_batch ["positions" ],
301+ attention_mask = mock_batch ["attention_mask" ],
302+ decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
303+ decoder_target_tokens = mock_batch ["targets" ],
304+ decoder_target_mask = mock_targets_segmentation ,
305+ cache = None ,
306+ )
307+ trainer .strategy .teacher_forward_fn .assert_called_once_with (
308+ model = teacher_model ,
309+ input_tokens = mock_batch ["input_tokens" ],
310+ positions = mock_batch ["positions" ],
311+ attention_mask = mock_batch ["attention_mask" ],
312+ decoder_segment_ids = mock_batch ["decoder_segment_ids" ],
313+ decoder_target_tokens = mock_batch ["targets" ],
314+ decoder_target_mask = mock_targets_segmentation ,
315+ cache = None ,
316+ )
317+
252318 def test_optimizer_factory (self ):
253319 """Verifies the optimizer factory injects hyperparams and handles configs."""
254320 # Mock config
0 commit comments