feature: loss watchdog for terminating training runs that are failing (#899)

Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
This commit is contained in:
kallewoof
2023-12-04 21:54:34 +09:00
committed by GitHub
parent 476a205cea
commit 58ec8b1113
4 changed files with 40 additions and 0 deletions

View File

@@ -694,6 +694,9 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
save_safetensors:

View File

@@ -62,6 +62,9 @@ logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
eval_steps: 0.05
eval_table_size:

View File

@@ -25,6 +25,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
@@ -430,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -124,6 +124,36 @@ class GPUStatsCallback(
return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [