Skip to content

Commit 31015aa

Browse files
committed
Fix learning rate schedule
1 parent 95ef3e1 commit 31015aa

2 files changed

Lines changed: 11 additions & 1 deletion

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1376,7 +1376,7 @@ def schedule(step):
1376 1376
boundaries = []
1377 1377

1378 1378
if warmup_steps > 0:
1379-
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps - 1)
1379+
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
1380 1380
pieces.append(warmup_schedule)
1381 1381
boundaries.append(warmup_steps)
1382 1382

tests/unit/maxtext_utils_test.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -748,6 +748,11 @@ def test_cosine_schedule(self):
748 748
# Warmup phase: 0 -> peak
749 749
self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
750 750
self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
751+
# Ensure delta is constant
752+
expected_slope = learning_rate / warmup_steps
753+
for i in range(1, warmup_steps + 1):
754+
current_lr = float(schedule_fn(i))
755+
self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6)
751 756

752 757
# Cosine decay phase
753 758
lr_end = schedule_fn(learning_rate_schedule_steps - 1)
@@ -791,6 +796,11 @@ def test_wsd_schedule(self):
791 796
# Warmup phase: 0 -> peak
792 797
self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
793 798
self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
799+
# Ensure delta is constant
800+
expected_slope = learning_rate / warmup_steps
801+
for i in range(1, warmup_steps + 1):
802+
current_lr = float(schedule_fn(i))
803+
self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6)
794 804

795 805
# Stable phase: constant at peak
796 806
self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)

0 commit comments

Comments
 (0)