From 33d094721c823fe98a5e387c9d61d9815607fcc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carsten=20Kragelund=20J=C3=B8rgensen?= Date: Mon, 4 Aug 2025 16:23:49 +0200 Subject: [PATCH] fix: deepcopy lr in RexLR scheduler. (#3012) * fix: deepcopy lr in RexLR scheduler. This fixes a problem where when the lr is a scalar tensor, the base_lrs in the get_lr function end up being references to the current learning rate, rather than the correct initial learning rate. See also related pytorch PR https://github.com/pytorch/pytorch/pull/127190/ * fix: add missing torch.Tensor import --- src/axolotl/utils/schedulers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index b9d09ad9c..cdaf92271 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -4,6 +4,7 @@ import math from functools import partial from typing import Sequence +from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler @@ -45,8 +46,10 @@ class RexLR(LRScheduler): # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) - + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) # Pass self.last_step as last_epoch to the parent. super().__init__(optimizer, last_epoch=self.last_step)