From a21b9cc472f9d7736db72e7f5dc4d8df962c1c28 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 30 Apr 2025 03:31:57 -0400
Subject: [PATCH] patch to convert LR from tensor to float when using DS
 (#2595) [skip ci]

---
 src/axolotl/core/trainer_builder.py   |  3 ++
 src/axolotl/monkeypatch/trainer/lr.py | 42 +++++++++++++++++++++++++++
 2 files changed, 45 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/trainer/lr.py

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 970b02075..358058f69 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -60,6 +60,7 @@ from axolotl.core.training_args import (
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
+from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
 from axolotl.processing_strategies import get_processing_strategy
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
@@ -114,6 +115,8 @@ class TrainerBuilderBase(abc.ABC):
         if hasattr(model, "add_model_tags"):
             model.add_model_tags(["axolotl"])
 
+        patch_trainer_get_lr()
+
     @property
     def model_ref(self):
         return self._model_ref
diff --git a/src/axolotl/monkeypatch/trainer/lr.py b/src/axolotl/monkeypatch/trainer/lr.py
new file mode 100644
index 000000000..0176093d6
--- /dev/null
+++ b/src/axolotl/monkeypatch/trainer/lr.py
@@ -0,0 +1,42 @@
+"""
+monkeypatch for Trainer _get_learning_rate method
+"""
+
+import logging
+
+import torch
+
+LOG = logging.getLogger(__name__)
+
+
+# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
+def _get_learning_rate(self):
+    if self.is_deepspeed_enabled:
+        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
+        # not run for the first few dozen steps while loss scale is too large, and thus during
+        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
+        try:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+        except AssertionError as e:
+            if "need to call step" in str(e):
+                LOG.warning(
+                    "tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
+                )
+                last_lr = 0
+            else:
+                raise
+    else:
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            last_lr = self.optimizer.param_groups[0]["lr"]
+        else:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+
+    if torch.is_tensor(last_lr):
+        last_lr = last_lr.item()
+    return last_lr
+
+
+def patch_trainer_get_lr():
+    from transformers.trainer import Trainer
+
+    Trainer._get_learning_rate = _get_learning_rate  # pylint: disable=protected-access