Early stopping metric (#537)
* set early stopping metric to check * tweak how load_best_model_at_end gets set for early stopping * add validation for early stopping patience * remove negation * save results to metrics in callback * move early stopping callback after the benchmark evals * broadcast metrics so early stopping works
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
|||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
|
broadcast_dict,
|
||||||
gather_scalar_from_all_ranks,
|
gather_scalar_from_all_ranks,
|
||||||
get_world_size,
|
get_world_size,
|
||||||
is_distributed,
|
is_distributed,
|
||||||
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
lambda: len(data_loader), get_world_size()
|
lambda: len(data_loader), get_world_size()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
results = {}
|
||||||
if is_distributed() and not is_main_process():
|
if is_distributed() and not is_main_process():
|
||||||
dist.gather_object(local_bench_names, dst=0)
|
dist.gather_object(local_bench_names, dst=0)
|
||||||
else:
|
else:
|
||||||
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
)["accuracy"]
|
)["accuracy"]
|
||||||
trainer.log(results)
|
trainer.log(results)
|
||||||
|
|
||||||
|
results = broadcast_dict(results)
|
||||||
|
for key, val in results.items():
|
||||||
|
metrics[key] = val
|
||||||
|
|
||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|||||||
@@ -220,6 +220,15 @@ def validate_config(cfg):
|
|||||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.early_stopping_patience:
|
||||||
|
if not cfg.save_steps or not cfg.eval_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
|
if cfg.save_steps % cfg.eval_steps != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
utility helpers for distributed checks
|
utility helpers for distributed checks
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
import pickle # nosec
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|||||||
gathered_values.append(float(tensor.item()))
|
gathered_values.append(float(tensor.item()))
|
||||||
return gathered_values
|
return gathered_values
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_dict(vals: dict):
    """Send the main process's dictionary to every other process.

    When the run is not distributed the input is returned untouched.
    Otherwise the main process pickles ``vals`` and the bytes are shipped
    through two ``dist.broadcast`` calls with source rank 0: first the
    payload length, then the payload itself. Every other process ignores
    its own ``vals`` argument and returns the unpickled copy instead.

    NOTE(review): this assumes ``is_main_process()`` is true exactly on
    rank 0, the broadcast source used below -- confirm against its
    definition.

    Args:
        vals: Dictionary to share. Only the main process's value is used.

    Returns:
        ``vals`` itself on the main process (or when not distributed); the
        unpickled copy of the main process's dictionary elsewhere.
    """
    if not is_distributed():
        return vals

    # Evaluate once; the answer cannot change during this call.
    on_main = is_main_process()

    if on_main:
        payload = pickle.dumps(vals)
        data_tensor = torch.tensor(list(payload), dtype=torch.uint8, device="cuda")
        data_size = torch.tensor([len(payload)], dtype=torch.int32, device="cuda")
    else:
        # The receive buffer cannot be sized until the length arrives, so no
        # placeholder buffer is allocated up front.
        data_tensor = None
        data_size = torch.tensor([0], dtype=torch.int32, device="cuda")

    # Step 1: tell every process how many bytes to expect.
    dist.broadcast(data_size, 0)

    if not on_main:
        data_tensor = torch.empty(
            [int(data_size.item())], dtype=torch.uint8, device="cuda"
        )

    # Step 2: ship the pickled bytes.
    dist.broadcast(data_tensor, 0)

    if not on_main:
        # The buffer was allocated at exactly the payload length, so no
        # trimming is needed. The bytes originate from a peer process of the
        # same job; never feed pickle data from an untrusted source.
        vals = pickle.loads(bytes(data_tensor.cpu().tolist()))  # nosec

    return vals
|
||||||
|
|||||||
@@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
||||||
if cfg.bench_dataset:
|
if cfg.bench_dataset:
|
||||||
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
||||||
|
if cfg.metric_for_best_model:
|
||||||
|
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
|
||||||
|
if cfg.greater_is_better:
|
||||||
|
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if cfg.ddp_timeout:
|
if cfg.ddp_timeout:
|
||||||
@@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||||
load_best_model_at_end=(
|
load_best_model_at_end=(
|
||||||
cfg.load_best_model_at_end is not False
|
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
||||||
and cfg.val_set_size > 0
|
and cfg.val_set_size > 0
|
||||||
and cfg.save_steps
|
and cfg.save_steps
|
||||||
and cfg.save_steps % cfg.eval_steps == 0
|
and cfg.save_steps % cfg.eval_steps == 0
|
||||||
and cfg.load_in_8bit is not True
|
|
||||||
)
|
)
|
||||||
or False,
|
or False,
|
||||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||||
@@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
callbacks.append(ReLoRACallback(cfg))
|
callbacks.append(ReLoRACallback(cfg))
|
||||||
|
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
|
||||||
if cfg.early_stopping_patience:
|
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
|
||||||
cfg.early_stopping_patience,
|
|
||||||
)
|
|
||||||
callbacks.append(early_stop_cb)
|
|
||||||
|
|
||||||
if cfg.local_rank == 0 and cfg.adapter in [
|
if cfg.local_rank == 0 and cfg.adapter in [
|
||||||
"lora",
|
"lora",
|
||||||
"qlora",
|
"qlora",
|
||||||
@@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
if cfg.do_bench_eval:
|
if cfg.do_bench_eval:
|
||||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||||
|
|
||||||
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
|
if cfg.early_stopping_patience:
|
||||||
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
cfg.early_stopping_patience,
|
||||||
|
)
|
||||||
|
trainer.add_callback(early_stop_cb)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|||||||
Reference in New Issue
Block a user