fixes to save on fractional save_steps (#1643)

This commit is contained in:
Wing Lian
2024-05-20 14:24:45 -04:00
committed by GitHub
parent 8a1572a831
commit ba45531802
2 changed files with 12 additions and 4 deletions

View File

@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelOnTrainEndCallback,
SaveModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
callbacks.append(SaveModelOnTrainEndCallback())
callbacks.append(SaveModelCallback())
return callbacks
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelOnTrainEndCallback())
callbacks.append(SaveModelCallback())
return callbacks

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import math
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control
class SaveModelOnTrainEndCallback(TrainerCallback):
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
elif (
args.save_strategy == IntervalStrategy.STEPS
and state.save_steps < 1.0
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
):
# workaround: a save_steps below 1.0 is a fraction of max_steps, so save every ceil(save_steps * max_steps) steps
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs