Lint schedulers

This commit is contained in:
NanoCode012
2023-05-29 14:25:15 +09:00
parent dae14e5951
commit fe1f4c4e7d

View File

@@ -1,7 +1,13 @@
"""Module for custom LRScheduler class"""
from torch.optim.lr_scheduler import LRScheduler
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion
"""
def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
"""A scheduler that interpolates learning rates in a logarithmic fashion
@@ -19,7 +25,9 @@ class InterpolatingLogScheduler(LRScheduler):
self.num_steps = num_steps
self.min_lr = min_lr
self.max_lr = max_lr
self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name
1 / (num_steps - 1)
)
super().__init__(optimizer, last_epoch)
def get_lr(self):