@@ -249,6 +249,66 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
249249 self .assertEqual (loss , mock_loss )
250250 self .assertEqual (aux , mock_aux )
251251
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map")
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad")
def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map):
  """Checks that `targets_segmentation` from the batch reaches the strategy callbacks.

  The callable handed to the patched `nnx.value_and_grad` is captured and
  invoked by hand, after which the test asserts that:
    * `labels_fn` received the targets plus `targets_segmentation`, and
    * both `student_forward_fn` and `teacher_forward_fn` received the targets as
      `decoder_target_tokens` and the segmentation as `decoder_target_mask`.

  `mock_tree_map` is never referenced; it is only present because the outer
  `mock.patch` decorator injects it.
  """
  # Build the trainer without running __init__, then give it a mocked strategy.
  # pylint: disable=no-value-for-parameter
  trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
  trainer.strategy = mock.Mock()

  # Batch returned by gen_model_input_fn: opaque mocks for every field except
  # the segmentation array, which is the value tracked through the callbacks.
  segmentation = jnp.array([[1, 1, 0]])
  batch = {
      field: mock.Mock()
      for field in ("input_tokens", "positions", "attention_mask", "decoder_segment_ids", "targets")
  }
  batch["targets_segmentation"] = segmentation
  trainer.gen_model_input_fn = mock.Mock(return_value=batch)

  # Teacher/student pair wrapped in the bundle that _train_step consumes.
  teacher_model = mock.Mock()
  student_model = mock.Mock()
  bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model)

  # The patched nnx.value_and_grad yields a callable returning ((loss, aux), grads).
  fake_value_and_aux = (mock.Mock(), mock.Mock())
  fake_grads = mock.Mock()
  mock_value_and_grad.return_value = mock.Mock(return_value=(fake_value_and_aux, fake_grads))

  # Run the step with placeholder optimizer/inputs, then pull out the first
  # positional argument given to nnx.value_and_grad and call it directly.
  trainer._train_step(bundle, mock.Mock(), mock.Mock())
  positional_args, _ = mock_value_and_grad.call_args
  loss_wrapper = positional_args[0]
  loss_wrapper(student_model, teacher_model, batch)

  # labels_fn gets the raw targets and the segmentation as a keyword.
  trainer.strategy.labels_fn.assert_called_once_with(batch["targets"], targets_segmentation=segmentation)

  # Student and teacher forwards share every keyword except `model`.
  shared_forward_kwargs = {
      "input_tokens": batch["input_tokens"],
      "positions": batch["positions"],
      "attention_mask": batch["attention_mask"],
      "decoder_segment_ids": batch["decoder_segment_ids"],
      "decoder_target_tokens": batch["targets"],
      "decoder_target_mask": segmentation,
      "cache": None,
  }
  trainer.strategy.student_forward_fn.assert_called_once_with(model=student_model, **shared_forward_kwargs)
  trainer.strategy.teacher_forward_fn.assert_called_once_with(model=teacher_model, **shared_forward_kwargs)
311+
252312 def test_optimizer_factory (self ):
253313 """Verifies the optimizer factory injects hyperparams and handles configs."""
254314 # Mock config
0 commit comments