fix: deepcopy lr in RexLR scheduler. (#3012)

* fix: deepcopy lr in RexLR scheduler.

This fixes a problem where, when the lr is a scalar tensor, the `base_lrs` used in the `get_lr` function end up being references to the current learning rate rather than the correct initial learning rate.

See also related pytorch PR https://github.com/pytorch/pytorch/pull/127190/

* fix: add missing torch.Tensor import
This commit is contained in:
Carsten Kragelund Jørgensen
2025-08-04 16:23:49 +02:00
committed by GitHub
parent a54c1be972
commit 33d094721c

View File

@@ -4,6 +4,7 @@ import math
from functools import partial from functools import partial
from typing import Sequence from typing import Sequence
from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -45,8 +46,10 @@ class RexLR(LRScheduler):
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups: for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"]) initial_lr = group["lr"]
if isinstance(initial_lr, Tensor):
initial_lr = initial_lr.clone()
group.setdefault("initial_lr", initial_lr)
# Pass self.last_step as last_epoch to the parent. # Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step) super().__init__(optimizer, last_epoch=self.last_step)