
Commit 22de090

Distillation optimizer fix
1 parent 093ab89 commit 22de090

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

@@ -199,7 +199,11 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
   """
 
   def __init__(self, model, strategy, optimizer, training_config, **kwargs):
-    super().__init__(model=model, optimizer=optimizer, training_config=training_config, **kwargs)
+    # Pass a dummy optimizer to the base PeftTrainer temporarily: this keeps it from eagerly
+    # allocating massive optimizer states for the entire ModelBundle (including the frozen
+    # teacher) before the real trainer optimizer is defined below.
+    dummy_optimizer = optax.set_to_zero()
+    super().__init__(model=model, optimizer=dummy_optimizer, training_config=training_config, **kwargs)
 
     self.strategy = strategy
 
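For context, here is a minimal runnable sketch of the pattern this commit applies. It is not the actual MaxText code: BaseTrainer and DistillTrainer are hypothetical stand-ins for peft_trainer.PeftTrainer and MaxTextDistillationTrainer. The key fact it demonstrates is that optax.set_to_zero() is a stateless no-op gradient transformation, so a base constructor that eagerly calls optimizer.init(params) allocates no per-parameter buffers for it, and the subclass can then install the real optimizer over the trainable subset only.

import jax.numpy as jnp
import optax


class BaseTrainer:
  """Hypothetical stand-in for peft_trainer.PeftTrainer: builds optimizer state eagerly."""

  def __init__(self, params, optimizer):
    self.params = params
    self.optimizer = optimizer
    # Eager allocation: for adam this would materialize two extra copies of
    # every array in `params`, including any frozen teacher weights.
    self.opt_state = optimizer.init(params)


class DistillTrainer(BaseTrainer):
  """Hypothetical stand-in for MaxTextDistillationTrainer."""

  def __init__(self, params, trainable_params, optimizer):
    # set_to_zero() carries no state, so the base class's init() allocates
    # nothing for the frozen teacher weights bundled into `params`.
    super().__init__(params, optax.set_to_zero())
    # Redefine the real optimizer over the trainable (student) subset only.
    self.optimizer = optimizer
    self.opt_state = optimizer.init(trainable_params)


params = {
    "teacher": jnp.zeros((4096, 4096)),  # frozen: should get no optimizer state
    "student": jnp.zeros((256, 256)),    # trainable
}
trainer = DistillTrainer(params, params["student"], optax.adam(1e-3))
print(type(trainer.opt_state))  # adam state covering the student params only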
