Skip to content

Commit 1a45bf0

Browse files
Merge pull request #3452 from AI-Hypercomputer:optimizer_ga_fix_2
PiperOrigin-RevId: 886886106
2 parents c7355aa + 4f6d05e commit 1a45bf0

3 files changed

Lines changed: 7 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ def compute_loss(
301301
"distill/teacher_loss": teacher_hard_loss,
302302
"distill/out_proj_feature_loss": feature_loss,
303303
"distill/total_loss": total_loss,
304+
"distill/temperature": self.temperature,
305+
"distill/alpha": self.alpha,
306+
"distill/beta_feature": self.beta_feature,
304307
}
305308
return total_loss, metrics
306309

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
480480
)
481481

482482
# 4. Optimizer & Config
483-
total_updates = student_config.steps // student_config.gradient_accumulation_steps
484-
optimizer = get_distillation_optimizer(student_config, total_updates)
483+
optimizer = get_distillation_optimizer(student_config, student_config.steps)
485484

486485
checkpointing_options = checkpoint.CheckpointManagerOptions(
487486
save_interval_steps=student_config.checkpoint_period,

tests/post_training/unit/train_distill_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,9 @@ def _test_monitored_strategy(self, sft_mode: bool):
399399
"distill/teacher_loss",
400400
"distill/out_proj_feature_loss",
401401
"distill/total_loss",
402+
"distill/temperature",
403+
"distill/alpha",
404+
"distill/beta_feature",
402405
]
403406
for key in expected_keys:
404407
self.assertIn(key, metrics)

0 commit comments

Comments
 (0)