Skip to content

Commit 72e96f5

Browse files
Merge pull request #3095 from AI-Hypercomputer:jimmytsai/fix-learning-rate-schedule
PiperOrigin-RevId: 878971906
2 parents c24d321 + 31015aa commit 72e96f5

2 files changed

Lines changed: 11 additions & 1 deletion

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1432,7 +1432,7 @@ def schedule(step):
14321432
boundaries = []
14331433

14341434
if warmup_steps > 0:
1435-
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps - 1)
1435+
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
14361436
pieces.append(warmup_schedule)
14371437
boundaries.append(warmup_steps)
14381438

tests/unit/maxtext_utils_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,11 @@ def test_cosine_schedule(self):
816816
# Warmup phase: 0 -> peak
817817
self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
818818
self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
819+
# Ensure delta is constant
820+
expected_slope = learning_rate / warmup_steps
821+
for i in range(1, warmup_steps + 1):
822+
current_lr = float(schedule_fn(i))
823+
self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6)
819824

820825
# Cosine decay phase
821826
lr_end = schedule_fn(learning_rate_schedule_steps - 1)
@@ -859,6 +864,11 @@ def test_wsd_schedule(self):
859864
# Warmup phase: 0 -> peak
860865
self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
861866
self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
867+
# Ensure delta is constant
868+
expected_slope = learning_rate / warmup_steps
869+
for i in range(1, warmup_steps + 1):
870+
current_lr = float(schedule_fn(i))
871+
self.assertAlmostEqual(current_lr - float(schedule_fn(i - 1)), expected_slope, places=6)
862872

863873
# Stable phase: constant at peak
864874
self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)

0 commit comments

Comments
 (0)