From 58ec8b1113e97bc608a0b943e66734eeb847015f Mon Sep 17 00:00:00 2001
From: kallewoof
Date: Mon, 4 Dec 2023 21:54:34 +0900
Subject: [PATCH] feature: loss watchdog for terminating training runs that
 are failing (#899)

Co-authored-by: Karl-Johan Alm
---
 README.md                           |  3 +++
 examples/mistral/qlora.yml          |  3 +++
 src/axolotl/core/trainer_builder.py |  4 ++++
 src/axolotl/utils/callbacks.py      | 30 +++++++++++++++++++++++++++++
 4 files changed, 40 insertions(+)

diff --git a/README.md b/README.md
index 093b8210d..ebe650c29 100644
--- a/README.md
+++ b/README.md
@@ -694,6 +694,9 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 
+loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
+loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
+
 # Save model as safetensors (require safetensors package)
 save_safetensors:
 
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 28c5ed242..b6dd46a55 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -62,6 +62,9 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true
 
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
 warmup_steps: 10
 eval_steps: 0.05
 eval_table_size:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 62e527beb..b4c22e548 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -25,6 +25,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
+    LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
@@ -430,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
 
+        if self.cfg.loss_watchdog_threshold is not None:
+            callbacks.append(LossWatchDogCallback(self.cfg))
+
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 4191bcf16..8599c0df0 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -124,6 +124,36 @@ class GPUStatsCallback(
         return control
 
 
+class LossWatchDogCallback(TrainerCallback):
+    """Callback to track loss and stop training if loss is too high"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+        self.violations = 0
+        self.threshold = cfg.loss_watchdog_threshold
+        self.patience = cfg.loss_watchdog_patience or 3
+
+    def on_step_end(
+        self,
+        _args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ):
+        if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
+            if state.log_history[-1]["loss"] > self.threshold:
+                self.violations += 1
+                if self.violations >= self.patience:
+                    LOG.warning(
+                        "Loss is too high, stopping training (loss_watchdog_threshold)"
+                    )
+                    control.should_training_stop = True
+            else:
+                self.violations = 0
+        return control
+
+
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
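
For anyone who wants to sanity-check the watchdog semantics without a full training run, here is a minimal sketch that drives the callback with a fabricated log history: consecutive high-loss steps trip the stop flag, and a single step back under the threshold resets the counter. This is illustrative only, not part of the patch; it assumes axolotl (with this change applied) and transformers are importable, and the SimpleNamespace below is a hypothetical stand-in for axolotl's real cfg object.

from types import SimpleNamespace

from transformers import TrainerControl, TrainerState

from axolotl.utils.callbacks import LossWatchDogCallback

# Hypothetical stand-in for axolotl's cfg object; the callback only reads
# the two watchdog fields.
cfg = SimpleNamespace(loss_watchdog_threshold=5.0, loss_watchdog_patience=3)
watchdog = LossWatchDogCallback(cfg)

state = TrainerState()      # starts with an empty log_history
control = TrainerControl()  # should_training_stop defaults to False

# Two violations, a recovery that resets the counter, then three
# consecutive violations; only the last step should stop training.
for loss in [7.2, 6.9, 1.8, 8.0, 7.5, 6.6]:
    state.log_history.append({"loss": loss})
    watchdog.on_step_end(None, state, control)  # args is unused by the callback
    print(loss, control.should_training_stop)
# -> prints False for the first five steps, True after the sixth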