patch to convert LR from tensor to float when using DS (#2595) [skip ci]

This commit is contained in:
Wing Lian
2025-04-30 03:31:57 -04:00
parent 097e7e3b5b
commit ee00142cb5
2 changed files with 45 additions and 0 deletions

View File

@@ -60,6 +60,7 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
@@ -114,6 +115,8 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref

View File

@@ -0,0 +1,42 @@
"""
monkeypatch for Trainer _get_learning_rate method
"""
import logging
import torch
# Module-level logger; _get_learning_rate uses it to warn when the LR is
# requested before the scheduler/optimizer has started stepping.
LOG = logging.getLogger(__name__)
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
def _get_learning_rate(self):
    """
    Return the most recent learning rate as a plain Python number.

    Intended to be installed on ``transformers.Trainer`` in place of its own
    ``_get_learning_rate``. ``self`` is expected to expose
    ``is_deepspeed_enabled``, ``lr_scheduler`` and ``optimizer``.

    Only the first entry reported by the scheduler (or, for
    ``ReduceLROnPlateau``, the first optimizer param group) is returned.
    If the value is a tensor it is converted with ``.item()`` regardless of
    which branch produced it, so the DeepSpeed path is covered as well.

    Raises:
        AssertionError: re-raised from ``lr_scheduler.get_last_lr()`` under
            DeepSpeed when the message is not the known "need to call step"
            warm-up failure.
    """
    if self.is_deepspeed_enabled:
        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
        # not run for the first few dozen steps while loss scale is too large, and thus during
        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
        try:
            last_lr = self.lr_scheduler.get_last_lr()[0]
        except AssertionError as e:
            if "need to call step" not in str(e):
                # Unrelated assertion failure: do not mask it.
                raise
            LOG.warning(
                "tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
            )
            last_lr = 0
    elif isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        # For ReduceLROnPlateau the value is read from the optimizer directly.
        last_lr = self.optimizer.param_groups[0]["lr"]
    else:
        last_lr = self.lr_scheduler.get_last_lr()[0]

    # Kept at function level (not inside a branch) so tensor-valued learning
    # rates are unwrapped for the DeepSpeed path too.
    if torch.is_tensor(last_lr):
        last_lr = last_lr.item()
    return last_lr
def patch_trainer_get_lr():
    """
    Replace ``Trainer._get_learning_rate`` with this module's implementation.

    The assignment is made on the ``Trainer`` class itself, so it is visible
    through every reference to that class.
    """
    # Imported inside the function so that importing this module does not
    # itself import transformers.
    import transformers.trainer as hf_trainer

    trainer_cls = hf_trainer.Trainer
    trainer_cls._get_learning_rate = _get_learning_rate  # pylint: disable=protected-access