From e30f1e3cf7bfa8d5e7bf50a305e0f5c67fbf7b4c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Sep 2023 11:57:02 -0400 Subject: [PATCH] Early stopping metric (#537) * set early stopping metric to check * tweak how load_best_model_at_end gets set for early stopping * add validation for early stopping patience * remove negation * save results to metrics in callback * move early stopping callback after the benchmark evals * broadcast metrics so early stopping works --- src/axolotl/utils/callbacks.py | 6 ++++++ src/axolotl/utils/config.py | 9 +++++++++ src/axolotl/utils/distributed.py | 28 ++++++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 21 ++++++++++++--------- 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 8fc5a918b..3f776537a 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.distributed import ( barrier, + broadcast_dict, gather_scalar_from_all_ranks, get_world_size, is_distributed, @@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer): lambda: len(data_loader), get_world_size() ) + results = {} if is_distributed() and not is_main_process(): dist.gather_object(local_bench_names, dst=0) else: @@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer): )["accuracy"] trainer.log(results) + results = broadcast_dict(results) + for key, val in results.items(): + metrics[key] = val + return BenchEvalCallback diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 7fc6e1232..6de807eab 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -220,6 +220,15 @@ def validate_config(cfg): "sample_packing not compatible with xformers_attention. 
Use flash_attention" ) + if cfg.early_stopping_patience: + if not cfg.save_steps or not cfg.eval_steps: + raise ValueError( + "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." + ) + if cfg.save_steps % cfg.eval_steps != 0: + raise ValueError( + "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." + ) # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 5e527f3b9..d48659db1 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -2,6 +2,7 @@ utility helpers for distributed checks """ import os +import pickle # nosec from contextlib import contextmanager import torch @@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n gathered_values.append(float(tensor.item())) return gathered_values return None + + +def broadcast_dict(vals: dict): + if not is_distributed(): + return vals + + if is_main_process(): + data_byte = pickle.dumps(vals) + data_tensor = torch.ByteTensor(list(data_byte)).to("cuda") + data_size = torch.IntTensor([len(data_byte)]).to("cuda") + else: + data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda") + data_size = torch.IntTensor([0]).to("cuda") + + dist.broadcast(data_size, 0) + if not is_main_process(): + # resize + data_tensor = data_tensor.new_empty([data_size.item()]) + + dist.broadcast(data_tensor, 0) + + if not is_main_process(): + data_list = data_tensor.cpu().tolist() + data_byte = bytes(data_list[: data_size.item()]) + vals = pickle.loads(data_byte) # nosec + + return vals diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 3bc283d75..ece1bd9b6 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ 
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval if cfg.bench_dataset: training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset + if cfg.metric_for_best_model: + training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model + if cfg.greater_is_better: + training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better # DDP Config if cfg.ddp_timeout: @@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ output_dir=cfg.output_dir, save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4, load_best_model_at_end=( - cfg.load_best_model_at_end is not False + (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience) and cfg.val_set_size > 0 and cfg.save_steps and cfg.save_steps % cfg.eval_steps == 0 - and cfg.load_in_8bit is not True ) or False, ddp_find_unused_parameters=False if cfg.ddp else None, @@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ if cfg.relora_steps: callbacks.append(ReLoRACallback(cfg)) - # TODO on_save callback to sync checkpoints to GCP/AWS in background - if cfg.early_stopping_patience: - early_stop_cb = EarlyStoppingCallback( - cfg.early_stopping_patience, - ) - callbacks.append(early_stop_cb) - if cfg.local_rank == 0 and cfg.adapter in [ "lora", "qlora", @@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) + # TODO on_save callback to sync checkpoints to GCP/AWS in background + if cfg.early_stopping_patience: + early_stop_cb = EarlyStoppingCallback( + cfg.early_stopping_patience, + ) + trainer.add_callback(early_stop_cb) + return trainer