Fix checkpoint saving when save_steps is fractional (#1643)
This commit is contained in:
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
|
|||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SaveModelOnTrainEndCallback,
|
SaveModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -945,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
callbacks.append(SaveModelOnTrainEndCallback())
|
callbacks.append(SaveModelCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -1431,7 +1431,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
callbacks.append(SaveModelOnTrainEndCallback())
|
callbacks.append(SaveModelCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class SaveModelOnTrainEndCallback(TrainerCallback):
|
class SaveModelCallback(TrainerCallback):
|
||||||
"""Callback to save model on train end"""
|
"""Callback to save model on train end"""
|
||||||
|
|
||||||
def on_step_end( # pylint: disable=unused-argument
|
def on_step_end( # pylint: disable=unused-argument
|
||||||
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
|
|||||||
# Save
|
# Save
|
||||||
if state.global_step >= state.max_steps:
|
if state.global_step >= state.max_steps:
|
||||||
control.should_save = True
|
control.should_save = True
|
||||||
|
elif (
|
||||||
|
args.save_strategy == IntervalStrategy.STEPS
|
||||||
|
and state.save_steps < 1.0
|
||||||
|
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
|
||||||
|
):
|
||||||
|
# workaround to save model on fractional save_steps
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
def on_train_end( # pylint: disable=unused-argument
|
def on_train_end( # pylint: disable=unused-argument
|
||||||
self, args, state, control, **kwargs
|
self, args, state, control, **kwargs
|
||||||
|
|||||||
Reference in New Issue
Block a user