patch to convert LR from tensor to float when using DS (#2595) [skip ci]
This commit is contained in:
@@ -60,6 +60,7 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
|
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||||
from axolotl.processing_strategies import get_processing_strategy
|
from axolotl.processing_strategies import get_processing_strategy
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
@@ -114,6 +115,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
|
patch_trainer_get_lr()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_ref(self):
|
def model_ref(self):
|
||||||
return self._model_ref
|
return self._model_ref
|
||||||
|
|||||||
42
src/axolotl/monkeypatch/trainer/lr.py
Normal file
42
src/axolotl/monkeypatch/trainer/lr.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
monkeypatch for Trainer _get_learning_rate method
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
|
||||||
|
def _get_learning_rate(self):
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
|
||||||
|
# not run for the first few dozen steps while loss scale is too large, and thus during
|
||||||
|
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
|
||||||
|
try:
|
||||||
|
last_lr = self.lr_scheduler.get_last_lr()[0]
|
||||||
|
except AssertionError as e:
|
||||||
|
if "need to call step" in str(e):
|
||||||
|
LOG.warning(
|
||||||
|
"tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
|
||||||
|
)
|
||||||
|
last_lr = 0
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
|
last_lr = self.optimizer.param_groups[0]["lr"]
|
||||||
|
else:
|
||||||
|
last_lr = self.lr_scheduler.get_last_lr()[0]
|
||||||
|
|
||||||
|
if torch.is_tensor(last_lr):
|
||||||
|
last_lr = last_lr.item()
|
||||||
|
return last_lr
|
||||||
|
|
||||||
|
|
||||||
|
def patch_trainer_get_lr():
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
Trainer._get_learning_rate = _get_learning_rate # pylint: disable=protected-access
|
||||||
Reference in New Issue
Block a user