fixes to save on fractional save_steps (#1643)

This commit is contained in:
Wing Lian
2024-05-20 14:24:45 -04:00
committed by GitHub
parent 8a1572a831
commit ba45531802
2 changed files with 12 additions and 4 deletions

View File

@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback, LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SaveModelOnTrainEndCallback, SaveModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory, causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None: if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg)) callbacks.append(LossWatchDogCallback(self.cfg))
callbacks.append(SaveModelOnTrainEndCallback()) callbacks.append(SaveModelCallback())
return callbacks return callbacks
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()
callbacks.append(SaveModelOnTrainEndCallback()) callbacks.append(SaveModelCallback())
return callbacks return callbacks

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import math
import os import os
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control return control
class SaveModelOnTrainEndCallback(TrainerCallback): class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end""" """Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument def on_step_end( # pylint: disable=unused-argument
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
# Save # Save
if state.global_step >= state.max_steps: if state.global_step >= state.max_steps:
control.should_save = True control.should_save = True
elif (
args.save_strategy == IntervalStrategy.STEPS
and state.save_steps < 1.0
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
):
# workaround to save model on fractional save_steps
control.should_save = True
def on_train_end( # pylint: disable=unused-argument def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs self, args, state, control, **kwargs