fixes to save on fractional save_steps (#1643)
This commit is contained in:
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SaveModelOnTrainEndCallback,
|
||||
SaveModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
callbacks.append(SaveModelOnTrainEndCallback())
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(SaveModelOnTrainEndCallback())
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||
class SaveModelCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_step_end( # pylint: disable=unused-argument
|
||||
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||
# Save
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_save = True
|
||||
elif (
|
||||
args.save_strategy == IntervalStrategy.STEPS
|
||||
and state.save_steps < 1.0
|
||||
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
|
||||
):
|
||||
# workaround to save model on fractional save_steps
|
||||
control.should_save = True
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
|
||||
Reference in New Issue
Block a user