handle when unable to save optimizer state when using ao optimizer with FSDP (#2773) [skip ci]
* handle when unable to save optimizer state when using ao optimizer with FSDP1 * improve messaging Co-authored-by: salman <salman.mohammadi@outlook.com> --------- Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
CheckpointSaveMixin,
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
@@ -39,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
|
||||
21
src/axolotl/core/trainers/mixins/checkpoints.py
Normal file
21
src/axolotl/core/trainers/mixins/checkpoints.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Custom handling to not fail training if fsdp optimizer is not savable"""
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CheckpointSaveMixin(Trainer):
|
||||
"""Mixin to handle saving the optimizer and scheduler if they are not savable."""
|
||||
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
try:
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
except NotImplementedError as exc:
|
||||
LOG.warning(
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible."
|
||||
)
|
||||
Reference in New Issue
Block a user