@@ -26,6 +26,7 @@
 from unittest import mock
 import jax
 import jax.numpy as jnp
+from flax import nnx
 import numpy as np
 import optax
 import orbax.checkpoint as ocp
@@ -537,6 +538,74 @@ def test_post_process_train_step(self):
     values_list = mock_buffer.additional_metrics["distill/kl_div"][0]
     self.assertEqual(values_list[0], 0.5)

+  def test_gradient_accumulation_requires_k_passes_for_update(self):
+    """Verifies that weights update only after k distinct forward passes."""
+
+    # 1. Set up a minimal NNX model.
+    class DummyModel(nnx.Module):
+
+      def __init__(self):
+        self.linear = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(0))
+
+      def __call__(self, x):
+        return self.linear(x)
+
+    student = DummyModel()
+    teacher = DummyModel()
+    model_bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student)
+
+    # Snapshot the initial weights.
+    initial_weights = student.linear.kernel.value.copy()
+
+    # 2. Set up the optimizer with MultiSteps (accumulate over 2 passes).
+    base_optimizer = optax.sgd(learning_rate=0.1)
+    accumulating_optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=2)
+    nnx_opt = nnx.Optimizer(student, accumulating_optimizer, wrt=nnx.Param)
+
+    # 3. Initialize the trainer and mocks (__new__ bypasses __init__ so we can inject mocks directly).
+    # pylint: disable=no-value-for-parameter
+    trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
+    trainer.strategy = mock.Mock()
+
+    dummy_batch = {
+        "input_tokens": jnp.ones((1, 2)),
+        "positions": None,
+        "targets": None,
+        "teacher_output": jnp.array([1.0, 1.0]),
+    }
+    trainer.gen_model_input_fn = mock.Mock(return_value=dummy_batch)
+    trainer.strategy.labels_fn.return_value = None
+
+    # 4. Mock the forward pass to COUNT how many times it executes.
+    # We wrap the actual dummy model execution in a mock to track it.
+    mock_student_forward = mock.Mock(side_effect=lambda model, **kwargs: model(dummy_batch["input_tokens"]))
+    trainer.strategy.student_forward_fn = mock_student_forward
+
+    trainer.strategy.compute_loss.side_effect = lambda s_out, t_out, labels: (jnp.sum(s_out), {"aux": 1.0})
+
+    # --- EXECUTE PASS 1 ---
+    trainer._train_step(model_bundle, nnx_opt, dummy_batch)
+
+    # ASSERTIONS AFTER PASS 1:
+    # Verify exactly ONE forward pass happened.
+    self.assertEqual(mock_student_forward.call_count, 1)
+
+    # Verify weights are completely UNCHANGED.
+    np.testing.assert_allclose(
+        student.linear.kernel.value, initial_weights, err_msg="Weights should not update on the first pass."
+    )
+
+    # --- EXECUTE PASS 2 ---
+    trainer._train_step(model_bundle, nnx_opt, dummy_batch)
+
+    # ASSERTIONS AFTER PASS 2:
+    # Verify exactly TWO forward passes have now happened.
+    self.assertEqual(mock_student_forward.call_count, 2)
+
+    # Verify weights HAVE changed (assert_allclose must now fail).
+    with self.assertRaises(AssertionError, msg="Weights should have updated on the second pass."):
+      np.testing.assert_allclose(student.linear.kernel.value, initial_weights)
+

 if __name__ == "__main__":
   absltest.main()
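
For reference, the gating behavior this test depends on can be reproduced with optax alone. The following is a minimal sketch (hypothetical values, no trainer or flax involved), assuming optax.MultiSteps with every_k_schedule=2: the first update emits all-zero updates while the gradient is accumulated, and the second applies the inner SGD step to the accumulated gradient.

import jax.numpy as jnp
import optax

params = {"w": jnp.array([1.0, 1.0])}
grads = {"w": jnp.array([1.0, 1.0])}
tx = optax.MultiSteps(optax.sgd(learning_rate=0.1), every_k_schedule=2)
state = tx.init(params)

# Pass 1: the gradient is only accumulated; the emitted update is all
# zeros, so apply_updates leaves the params untouched.
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)
assert bool(jnp.all(params["w"] == 1.0))

# Pass 2: the inner SGD step is applied to the accumulated gradient:
# 1.0 - 0.1 * 1.0 = 0.9.
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)
assert bool(jnp.allclose(params["w"], 0.9))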