diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index b9d09ad9c..cdaf92271 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -4,6 +4,7 @@
 import math
 from functools import partial
 from typing import Sequence
 
+from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -45,8 +46,10 @@ class RexLR(LRScheduler):
 
         # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
         for group in optimizer.param_groups:
-            group.setdefault("initial_lr", group["lr"])
-
+            initial_lr = group["lr"]
+            if isinstance(initial_lr, Tensor):
+                initial_lr = initial_lr.clone()
+            group.setdefault("initial_lr", initial_lr)
         # Pass self.last_step as last_epoch to the parent.
         super().__init__(optimizer, last_epoch=self.last_step)
 