fix: deepcopy lr in RexLR scheduler. (#3012)

* fix: deepcopy lr in RexLR scheduler.

This fixes a problem where, when the lr is a scalar tensor, the base_lrs in the get_lr function end up being references to the current learning rate rather than to the correct initial learning rate.

See also related pytorch PR https://github.com/pytorch/pytorch/pull/127190/

* fix: add missing torch.Tensor import.
This commit is contained in:
Carsten Kragelund Jørgensen
2025-08-04 16:23:49 +02:00
committed by GitHub
parent a54c1be972
commit 33d094721c

View File

@@ -4,6 +4,7 @@ import math
from functools import partial
from typing import Sequence
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -45,8 +46,10 @@ class RexLR(LRScheduler):
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
initial_lr = group["lr"]
if isinstance(initial_lr, Tensor):
initial_lr = initial_lr.clone()
group.setdefault("initial_lr", initial_lr)
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)