Early stopping metric (#537)

* set early stopping metric to check

* tweak how load_best_model_at_end gets set for early stopping

* add validation for earl;y stopping patience

* remove negation

* save results to metrics in callback

* move early stopping callback after the benchmark evals

* broadcast metrics so early stopping works
This commit is contained in:
Wing Lian
2023-09-08 11:57:02 -04:00
committed by GitHub
parent 343714972b
commit e30f1e3cf7
4 changed files with 55 additions and 9 deletions

View File

@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
broadcast_dict,
gather_scalar_from_all_ranks,
get_world_size,
is_distributed,
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
lambda: len(data_loader), get_world_size()
)
results = {}
if is_distributed() and not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
)["accuracy"]
trainer.log(results)
results = broadcast_dict(results)
for key, val in results.items():
metrics[key] = val
return BenchEvalCallback

View File

@@ -220,6 +220,15 @@ def validate_config(cfg):
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if cfg.save_steps % cfg.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -2,6 +2,7 @@
utility helpers for distributed checks
"""
import os
import pickle # nosec
from contextlib import contextmanager
import torch
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
gathered_values.append(float(tensor.item()))
return gathered_values
return None
def broadcast_dict(vals: dict):
if not is_distributed():
return vals
if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
dist.broadcast(data_size, 0)
if not is_main_process():
# resize
data_tensor = data_tensor.new_empty([data_size.item()])
dist.broadcast(data_tensor, 0)
if not is_main_process():
data_list = data_tensor.cpu().tolist()
data_byte = bytes(data_list[: data_size.item()])
vals = pickle.loads(data_byte) # nosec
return vals

View File

@@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
if cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
# DDP Config
if cfg.ddp_timeout:
@@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
@@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
@@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)
return trainer