Early stopping metric (#537)
* set early stopping metric to check * tweak how load_best_model_at_end gets set for early stopping * add validation for earl;y stopping patience * remove negation * save results to metrics in callback * move early stopping callback after the benchmark evals * broadcast metrics so early stopping works
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
broadcast_dict,
|
||||
gather_scalar_from_all_ranks,
|
||||
get_world_size,
|
||||
is_distributed,
|
||||
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
lambda: len(data_loader), get_world_size()
|
||||
)
|
||||
|
||||
results = {}
|
||||
if is_distributed() and not is_main_process():
|
||||
dist.gather_object(local_bench_names, dst=0)
|
||||
else:
|
||||
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
)["accuracy"]
|
||||
trainer.log(results)
|
||||
|
||||
results = broadcast_dict(results)
|
||||
for key, val in results.items():
|
||||
metrics[key] = val
|
||||
|
||||
return BenchEvalCallback
|
||||
|
||||
@@ -220,6 +220,15 @@ def validate_config(cfg):
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||
)
|
||||
if cfg.save_steps % cfg.eval_steps != 0:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||
)
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
utility helpers for distributed checks
|
||||
"""
|
||||
import os
|
||||
import pickle # nosec
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
||||
gathered_values.append(float(tensor.item()))
|
||||
return gathered_values
|
||||
return None
|
||||
|
||||
|
||||
def broadcast_dict(vals: dict):
|
||||
if not is_distributed():
|
||||
return vals
|
||||
|
||||
if is_main_process():
|
||||
data_byte = pickle.dumps(vals)
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
|
||||
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
|
||||
else:
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
|
||||
data_size = torch.IntTensor([0]).to("cuda")
|
||||
|
||||
dist.broadcast(data_size, 0)
|
||||
if not is_main_process():
|
||||
# resize
|
||||
data_tensor = data_tensor.new_empty([data_size.item()])
|
||||
|
||||
dist.broadcast(data_tensor, 0)
|
||||
|
||||
if not is_main_process():
|
||||
data_list = data_tensor.cpu().tolist()
|
||||
data_byte = bytes(data_list[: data_size.item()])
|
||||
vals = pickle.loads(data_byte) # nosec
|
||||
|
||||
return vals
|
||||
|
||||
@@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
||||
if cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
||||
if cfg.metric_for_best_model:
|
||||
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
|
||||
if cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
||||
|
||||
# DDP Config
|
||||
if cfg.ddp_timeout:
|
||||
@@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||
load_best_model_at_end=(
|
||||
cfg.load_best_model_at_end is not False
|
||||
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
||||
and cfg.val_set_size > 0
|
||||
and cfg.save_steps
|
||||
and cfg.save_steps % cfg.eval_steps == 0
|
||||
and cfg.load_in_8bit is not True
|
||||
)
|
||||
or False,
|
||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||
@@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
if cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(cfg))
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
cfg.early_stopping_patience,
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
if cfg.local_rank == 0 and cfg.adapter in [
|
||||
"lora",
|
||||
"qlora",
|
||||
@@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
if cfg.do_bench_eval:
|
||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
cfg.early_stopping_patience,
|
||||
)
|
||||
trainer.add_callback(early_stop_cb)
|
||||
|
||||
return trainer
|
||||
|
||||
Reference in New Issue
Block a user