Skip to content

Commit 4f6d05e

Browse files
committed
fix optimizer number of steps
1 parent 093ab89 commit 4f6d05e

3 files changed

Lines changed: 7 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def compute_loss(
248248
"distill/teacher_loss": teacher_hard_loss,
249249
"distill/out_proj_feature_loss": feature_loss,
250250
"distill/total_loss": total_loss,
251+
"distill/temperature": self.temperature,
252+
"distill/alpha": self.alpha,
253+
"distill/beta_feature": self.beta_feature,
251254
}
252255
return total_loss, metrics
253256

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
464464
)
465465

466466
# 4. Optimizer & Config
467-
total_updates = student_config.steps // student_config.gradient_accumulation_steps
468-
optimizer = get_distillation_optimizer(student_config, total_updates)
467+
optimizer = get_distillation_optimizer(student_config, student_config.steps)
469468

470469
checkpointing_options = checkpoint.CheckpointManagerOptions(
471470
save_interval_steps=student_config.checkpoint_period,

tests/post_training/unit/train_distill_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,9 @@ def _test_monitored_strategy(self, sft_mode: bool):
399399
"distill/teacher_loss",
400400
"distill/out_proj_feature_loss",
401401
"distill/total_loss",
402+
"distill/temperature",
403+
"distill/alpha",
404+
"distill/beta_feature",
402405
]
403406
for key in expected_keys:
404407
self.assertIn(key, metrics)

0 commit comments

Comments (0)