diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index e66f165a5..2965ac1e2 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): class SaveModelOnTrainEndCallback(TrainerCallback): """Callback to save model on train end""" + def on_step_end( # pylint: disable=unused-argument + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Save + if state.global_step >= state.max_steps: + control.should_save = True + def on_train_end( # pylint: disable=unused-argument self, args, state, control, **kwargs ):