From dd660c2ed046e8715cdf73c23cf14066ac165ce7 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 9 Jun 2025 21:26:14 -0700
Subject: [PATCH] handle when unable to save optimizer state when using ao optimizer with FSDP (#2773) [skip ci]

* handle when unable to save optimizer state when using ao optimizer with FSDP1

* improve messaging

Co-authored-by: salman

---------

Co-authored-by: salman
---
 src/axolotl/core/trainers/base.py             |  5 ++++-
 src/axolotl/core/trainers/mixins/__init__.py  |  1 +
 .../core/trainers/mixins/checkpoints.py       | 21 +++++++++++++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)
 create mode 100644 src/axolotl/core/trainers/mixins/checkpoints.py

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 70e443cb3..d6f2c579a 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
 from typing_extensions import override
 
 from axolotl.core.trainers.mixins import (
+    CheckpointSaveMixin,
     OptimizerMixin,
     RngLoaderMixin,
     SchedulerMixin,
@@ -39,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 LOG = get_logger(__name__)
 
 
-class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
+class AxolotlTrainer(
+    SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
+):
     """Extend the base Trainer for axolotl helpers"""
 
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py
index a71cb321a..178232077 100644
--- a/src/axolotl/core/trainers/mixins/__init__.py
+++ b/src/axolotl/core/trainers/mixins/__init__.py
@@ -3,6 +3,7 @@
 # pylint: disable=unused-import
 # flake8: noqa
 
+from .checkpoints import CheckpointSaveMixin
 from .optimizer import OptimizerMixin
 from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
diff 
--git a/src/axolotl/core/trainers/mixins/checkpoints.py b/src/axolotl/core/trainers/mixins/checkpoints.py
new file mode 100644
index 000000000..8f994d78c
--- /dev/null
+++ b/src/axolotl/core/trainers/mixins/checkpoints.py
@@ -0,0 +1,34 @@
+"""Custom handling to not fail training if fsdp optimizer is not savable"""
+
+from transformers import Trainer
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+class CheckpointSaveMixin(Trainer):
+    """Trainer mixin that keeps training running when optimizer state cannot be saved.
+
+    The parent save step may raise ``NotImplementedError`` (per this module's
+    docstring, when an FSDP optimizer is not savable); this mixin logs a warning
+    instead of letting that error propagate.
+    """
+
+    def _save_optimizer_and_scheduler(self, output_dir):
+        """Run the parent save step, turning ``NotImplementedError`` into a warning.
+
+        Args:
+            output_dir: Passed through unchanged to the parent implementation.
+        """
+        try:
+            super()._save_optimizer_and_scheduler(output_dir)
+        except NotImplementedError as err:
+            # Deliberately swallowed: a missing optimizer/scheduler state should not
+            # abort the run, but the user must know the checkpoint cannot be resumed.
+            message = (
+                f"Trainer does not support saving optimizer and scheduler: {err}\n"
+                "Optimizer and scheduler states were not saved - resuming from checkpoints "
+                "for this training run will not be possible."
+            )
+            LOG.warning(message)