fix: deepcopy lr in RexLR scheduler. (#3012)
* fix: deepcopy lr in RexLR scheduler. This fixes a problem where, when the lr is a scalar tensor, the base_lrs used in the get_lr function end up being references to the current learning rate rather than independent copies of the correct initial learning rate. See also the related PyTorch PR https://github.com/pytorch/pytorch/pull/127190/
* fix: add missing torch.Tensor import
This commit is contained in:
committed by
GitHub
parent
a54c1be972
commit
33d094721c
@@ -4,6 +4,7 @@ import math
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
@@ -45,8 +46,10 @@ class RexLR(LRScheduler):
|
|||||||
|
|
||||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
group.setdefault("initial_lr", group["lr"])
|
initial_lr = group["lr"]
|
||||||
|
if isinstance(initial_lr, Tensor):
|
||||||
|
initial_lr = initial_lr.clone()
|
||||||
|
group.setdefault("initial_lr", initial_lr)
|
||||||
# Pass self.last_step as last_epoch to the parent.
|
# Pass self.last_step as last_epoch to the parent.
|
||||||
super().__init__(optimizer, last_epoch=self.last_step)
|
super().__init__(optimizer, last_epoch=self.last_step)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user