@@ -816,6 +816,11 @@ def test_cosine_schedule(self):
816816 # Warmup phase: 0 -> peak
817817 self .assertAlmostEqual (float (schedule_fn (0 )), 0.0 , places = 6 )
818818 self .assertAlmostEqual (float (schedule_fn (warmup_steps )), learning_rate , places = 6 )
819+ # Ensure delta is constant
820+ expected_slope = learning_rate / warmup_steps
821+ for i in range (1 , warmup_steps + 1 ):
822+ current_lr = float (schedule_fn (i ))
823+ self .assertAlmostEqual (current_lr - float (schedule_fn (i - 1 )), expected_slope , places = 6 )
819824
820825 # Cosine decay phase
821826 lr_end = schedule_fn (learning_rate_schedule_steps - 1 )
@@ -859,6 +864,11 @@ def test_wsd_schedule(self):
859864 # Warmup phase: 0 -> peak
860865 self .assertAlmostEqual (float (schedule_fn (0 )), 0.0 , places = 6 )
861866 self .assertAlmostEqual (float (schedule_fn (warmup_steps )), learning_rate , places = 6 )
867+ # Ensure delta is constant
868+ expected_slope = learning_rate / warmup_steps
869+ for i in range (1 , warmup_steps + 1 ):
870+ current_lr = float (schedule_fn (i ))
871+ self .assertAlmostEqual (current_lr - float (schedule_fn (i - 1 )), expected_slope , places = 6 )
862872
863873 # Stable phase: constant at peak
864874 self .assertAlmostEqual (float (schedule_fn (warmup_steps + 10 )), learning_rate , places = 6 )
0 commit comments