fix: deepcopy lr in RexLR scheduler. (#3012)
* fix: deepcopy lr in RexLR scheduler. This fixes a problem where when the lr is a scalar tensor, the base_lrs in the get_lr function end up being references to the current learning rate, rather than the correct initial learning rate. See also related pytorch PR https://github.com/pytorch/pytorch/pull/127190/ * fix: add missing torch.Tensor import
This commit is contained in:
committed by
GitHub
parent
a54c1be972
commit
33d094721c
@@ -4,6 +4,7 @@ import math
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
@@ -45,8 +46,10 @@ class RexLR(LRScheduler):
|
||||
|
||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
|
||||
initial_lr = group["lr"]
|
||||
if isinstance(initial_lr, Tensor):
|
||||
initial_lr = initial_lr.clone()
|
||||
group.setdefault("initial_lr", initial_lr)
|
||||
# Pass self.last_step as last_epoch to the parent.
|
||||
super().__init__(optimizer, last_epoch=self.last_step)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user