Handle failure to save optimizer state when using an ao optimizer with FSDP (#2773) [skip ci]
* Handle failure to save optimizer state when using an ao optimizer with FSDP1 * Improve messaging Co-authored-by: salman <salman.mohammadi@outlook.com> --------- Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
|
CheckpointSaveMixin,
|
||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
SchedulerMixin,
|
SchedulerMixin,
|
||||||
@@ -39,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
class AxolotlTrainer(
|
||||||
|
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
|
||||||
|
):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .checkpoints import CheckpointSaveMixin
|
||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from .rng_state_loader import RngLoaderMixin
|
from .rng_state_loader import RngLoaderMixin
|
||||||
from .scheduler import SchedulerMixin
|
from .scheduler import SchedulerMixin
|
||||||
|
|||||||
21
src/axolotl/core/trainers/mixins/checkpoints.py
Normal file
21
src/axolotl/core/trainers/mixins/checkpoints.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Custom handling to not fail training if fsdp optimizer is not savable"""
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
# Module-level logger; used below to warn when optimizer/scheduler state cannot be saved.
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointSaveMixin(Trainer):
    """
    Trainer mixin that keeps a run going when the optimizer and scheduler
    state cannot be written during checkpointing.
    """

    def _save_optimizer_and_scheduler(self, output_dir):
        """
        Delegate to the parent implementation, downgrading a
        `NotImplementedError` raised by it to a logged warning.

        Args:
            output_dir: Passed through unchanged to the parent's
                `_save_optimizer_and_scheduler`.

        Only `NotImplementedError` is intercepted; any other exception from
        the parent call propagates to the caller.
        """
        try:
            super()._save_optimizer_and_scheduler(output_dir)
        except NotImplementedError as err:
            # Swallow the error on purpose: the warning tells the user that
            # this run can no longer be resumed from its checkpoints.
            warning_text = (
                f"Trainer does not support saving optimizer and scheduler: {err}\n"
                "Optimizer and scheduler states were not saved - resuming from checkpoints "
                "for this training run will not be possible."
            )
            LOG.warning(warning_text)
|
||||||
Reference in New Issue
Block a user