add optimizer step to prevent warning in tests (#1502) [skip ci]
* add optimizer step to prevent warning in tests * add optimizer step to warmup as well
This commit is contained in:
@@ -32,16 +32,19 @@ class TestCosineConstantLr(unittest.TestCase):
|
|||||||
def test_schedulers(self):
|
def test_schedulers(self):
|
||||||
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
|
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
|
||||||
for _ in range(self.warmup_steps):
|
for _ in range(self.warmup_steps):
|
||||||
|
self.optimizer.step()
|
||||||
self.lr_scheduler.step()
|
self.lr_scheduler.step()
|
||||||
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
|
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
|
||||||
constant_step = int(self.train_steps * self.constant_lr_ratio)
|
constant_step = int(self.train_steps * self.constant_lr_ratio)
|
||||||
remaining_step = self.train_steps - constant_step
|
remaining_step = self.train_steps - constant_step
|
||||||
for _ in range(constant_step):
|
for _ in range(constant_step):
|
||||||
|
self.optimizer.step()
|
||||||
self.lr_scheduler.step()
|
self.lr_scheduler.step()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
||||||
)
|
)
|
||||||
for _ in range(remaining_step):
|
for _ in range(remaining_step):
|
||||||
|
self.optimizer.step()
|
||||||
self.lr_scheduler.step()
|
self.lr_scheduler.step()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
||||||
|
|||||||
Reference in New Issue
Block a user