improve save callbacks (#1592)
This commit is contained in:
@@ -43,6 +43,7 @@ from axolotl.utils.callbacks import (
|
|||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
|
SaveModelOnTrainEndCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -888,6 +889,14 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -933,18 +942,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
):
|
):
|
||||||
callbacks.append(SaveBetterTransformerModelCallback())
|
callbacks.append(SaveBetterTransformerModelCallback())
|
||||||
|
|
||||||
if self.cfg.use_mlflow and is_mlflow_available():
|
|
||||||
from axolotl.utils.callbacks.mlflow_ import (
|
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
|
||||||
)
|
|
||||||
|
|
||||||
callbacks.append(
|
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
|
callbacks.append(SaveModelOnTrainEndCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -1427,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
|
callbacks.append(SaveModelOnTrainEndCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
|||||||
@@ -773,3 +773,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class SaveModelOnTrainEndCallback(TrainerCallback):
|
||||||
|
"""Callback to save model on train end"""
|
||||||
|
|
||||||
|
def on_train_end( # pylint: disable=unused-argument
|
||||||
|
self, args, state, control, **kwargs
|
||||||
|
):
|
||||||
|
control.should_save = True
|
||||||
|
return control
|
||||||
|
|||||||
Reference in New Issue
Block a user