diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7330a78ef..0fdd126f5 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -43,7 +43,7 @@ from axolotl.utils.callbacks import ( LossWatchDogCallback, SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, - SaveModelOnTrainEndCallback, + SaveModelCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, log_prediction_callback_factory, @@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) - callbacks.append(SaveModelOnTrainEndCallback()) + callbacks.append(SaveModelCallback()) return callbacks @@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() - callbacks.append(SaveModelOnTrainEndCallback()) + callbacks.append(SaveModelCallback()) return callbacks diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 2965ac1e2..c21ef0ad7 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import math import os from shutil import copyfile from tempfile import NamedTemporaryFile @@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): return control -class SaveModelOnTrainEndCallback(TrainerCallback): +class SaveModelCallback(TrainerCallback): """Callback to save model on train end""" def on_step_end( # pylint: disable=unused-argument @@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback): # Save if state.global_step >= state.max_steps: control.should_save = True + elif ( + args.save_strategy == IntervalStrategy.STEPS + and state.save_steps < 1.0 + and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0 + ): + # workaround to save model on fractional save_steps + control.should_save = True def on_train_end( # pylint: disable=unused-argument self, args, state, control, **kwargs