Add seq2seq eval benchmark callback (#1274)
* Add CausalLMBenchEvalCallback for measuring seq2seq performance * Fix code for pre-commit * Fix typing and improve logging * eval_sample_packing must be false with CausalLMBenchEvalCallback
This commit is contained in:
@@ -784,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
|
|||||||
max_steps:
|
max_steps:
|
||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]
|
||||||
|
|
||||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ s2_attention:
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ s2_attention:
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ flash_attention:
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ flash_attention: true
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
# default deepspeed, can use more aggressive if needed like zero2, zero3
|
# default deepspeed, can use more aggressive if needed like zero2, zero3
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ flash_attention: true
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed: deepspeed_configs/zero2.json
|
deepspeed: deepspeed_configs/zero2.json
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ flash_attention:
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ flash_attention:
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ num_epochs: 1
|
|||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
evals_per_epoch: 5
|
evals_per_epoch: 5
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_max_new_tokens: 128
|
||||||
eval_sample_packing: false
|
eval_sample_packing: false
|
||||||
eval_batch_size: 1
|
eval_batch_size: 1
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ numba
|
|||||||
numpy>=1.24.4
|
numpy>=1.24.4
|
||||||
mlflow
|
mlflow
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.0
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
@@ -148,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
do_bench_eval: Optional[bool] = field(
|
do_bench_eval: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||||
)
|
)
|
||||||
|
do_causal_lm_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||||
|
)
|
||||||
max_bench_samples: Optional[int] = field(
|
max_bench_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -664,6 +668,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.do_bench_eval:
|
if self.cfg.do_bench_eval:
|
||||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||||
|
if self.cfg.do_causal_lm_eval:
|
||||||
|
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
|
||||||
|
trainer, self.tokenizer
|
||||||
|
)
|
||||||
|
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
|
||||||
|
|
||||||
if self.cfg.early_stopping_patience:
|
if self.cfg.early_stopping_patience:
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
@@ -812,6 +821,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||||
if self.cfg.bench_dataset:
|
if self.cfg.bench_dataset:
|
||||||
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||||
|
if self.cfg.do_causal_lm_eval:
|
||||||
|
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
||||||
if self.cfg.metric_for_best_model:
|
if self.cfg.metric_for_best_model:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"metric_for_best_model"
|
"metric_for_best_model"
|
||||||
|
|||||||
@@ -361,6 +361,187 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|
||||||
|
|
||||||
|
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
    """Build a callback class that scores generated completions at eval time.

    The returned class closes over ``trainer`` and ``tokenizer``. It is
    instantiated by the caller with the run config (``cfg``), from which it
    reads ``eval_causal_lm_metrics``, ``eval_max_new_tokens`` and ``device``.

    Args:
        trainer: Trainer whose ``model`` is used for generation and whose
            ``log`` method receives the computed scores.
        tokenizer: Tokenizer used to decode dataset tokens, re-encode the
            prompts for generation, and decode the generated tokens.

    Returns:
        The ``CausalLMBenchEvalCallback`` class (not an instance).
    """

    class CausalLMBenchEvalCallback(TrainerCallback):
        """Callback to generate completions and log text metrics during each evaluation"""

        def __init__(self, cfg):
            self.cfg = cfg
            self.logged = False
            self.metrics = self.__maybe_load_metrics()

        def __maybe_load_metrics(self):
            """Load every configured metric, skipping (with a warning) any that fail."""
            metrics = {}
            for metric in self.cfg.eval_causal_lm_metrics:
                try:
                    metrics[metric] = evaluate.load(metric)
                except Exception as exc:  # pylint: disable=broad-exception-caught
                    LOG.warning(f"{metric}: {exc.args}")
            return metrics

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
            state: TrainerState,  # pylint: disable=unused-argument
            control: TrainerControl,
            train_dataloader,  # pylint: disable=unused-argument
            eval_dataloader,
            **kwargs,  # pylint: disable=unused-argument
        ):
            """Generate completions for the eval set and log metric scores.

            Generation and logging only happen on the main process. Returns
            ``control`` unchanged.
            """
            trainer.model.eval()
            device = torch.device(self.cfg.device)

            # pylint: disable=duplicate-code
            generation_config = GenerationConfig(
                max_new_tokens=self.cfg.eval_max_new_tokens,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )

            def find_ranges(lst):
                """Split a position-id list into inclusive (start, end) index
                ranges, starting a new range wherever the position resets to 0."""
                if not lst:
                    return []
                ranges = []
                start = 0
                for i in range(1, len(lst)):
                    if lst[i] == 0:
                        ranges.append((start, i - 1))
                        start = i
                ranges.append((start, len(lst) - 1))
                return ranges

            def compute(metric: evaluate.Metric, **kwargs):
                """Safely compute a metric.

                Returns the value stored under ``"score"`` (or, failing that,
                ``"mean_score"``), or ``None`` when the computation raises or
                the result has neither key.
                """
                try:
                    metric_score = metric.compute(**kwargs)
                except Exception:  # pylint: disable=broad-exception-caught
                    LOG.debug(
                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
                    )
                    return None

                for key in ("score", "mean_score"):
                    if metric_score and key in metric_score:
                        return metric_score[key]

                LOG.debug(
                    f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
                )
                return None

            def evaluate_preds(sources, predictions, references):
                """Score predictions against references with every loaded metric."""
                scores = {}

                for metric_name, metric in self.metrics.items():
                    # First try the signature that also takes the sources.
                    score = compute(
                        metric,
                        references=references,
                        predictions=predictions,
                        sources=sources,
                    )
                    # Fall back to nested references without sources. Test
                    # against None explicitly so a legitimate 0.0 score is
                    # not mistaken for a failure.
                    if score is None:
                        score = compute(
                            metric,
                            references=[[r] for r in references],
                            predictions=predictions,
                        )
                    if score is None:
                        LOG.warning(
                            f"Skipping metric {metric_name}: could not compute a score"
                        )
                        continue
                    scores[metric_name] = score
                return scores

            def predict_with_generate():
                """Return (prompt texts, generated texts, reference texts)."""
                eval_src, eval_pred, eval_ref = [], [], []

                for batch in tqdm(eval_dataloader):
                    batch_labels = batch["labels"].to(device)
                    batch_input_ids = batch["input_ids"].to(device)

                    if "position_ids" in batch:
                        batch_pos_ids = batch["position_ids"].tolist()
                    else:
                        batch_pos_ids = [None] * len(batch["input_ids"])

                    prompt_token_ids_list = []
                    completion_token_ids_list = []

                    for input_ids_all, labels_all, pos_ids in zip(
                        batch_input_ids,
                        batch_labels,
                        batch_pos_ids,
                    ):
                        if pos_ids is None:
                            pos_ranges = [(0, len(input_ids_all) - 1)]
                        else:
                            pos_ranges = find_ranges(pos_ids)

                        for start, end in pos_ranges:
                            # Single-token segments carry no prompt/completion pair.
                            if start == end:
                                continue

                            input_ids = input_ids_all[start : end + 1]
                            labels = labels_all[start : end + 1]

                            # Prompt tokens are the ones excluded from the loss
                            # (and not padding); completion tokens are the rest.
                            tokens_without_loss = labels == IGNORE_INDEX
                            tokens_with_loss = labels != IGNORE_INDEX
                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
                            prompt_token_includes = (
                                tokens_without_loss & tokens_exclude_padding
                            )

                            prompt_token_ids_list.append(
                                input_ids[prompt_token_includes]
                            )
                            completion_token_ids_list.append(
                                input_ids[tokens_with_loss]
                            )

                    # Nothing usable in this batch; tokenizing an empty list
                    # would fail.
                    if not prompt_token_ids_list:
                        continue

                    prompt_texts = tokenizer.batch_decode(
                        prompt_token_ids_list, skip_special_tokens=True
                    )
                    completion_texts = tokenizer.batch_decode(
                        completion_token_ids_list, skip_special_tokens=True
                    )

                    with torch.no_grad():
                        prompt_encoding = tokenizer(
                            prompt_texts, padding=True, return_tensors="pt"
                        ).to(device)
                        predictions = trainer.model.generate(
                            **prompt_encoding, generation_config=generation_config
                        )

                    # The model generated from the re-encoded, padded batch, so
                    # the prompt occupies exactly that batch's width in every
                    # returned sequence. The original per-sample prompt length
                    # can differ (padding, re-added special tokens).
                    # NOTE(review): assumes a decoder-only model whose generate()
                    # output includes the prompt tokens — confirm.
                    prompt_width = prompt_encoding["input_ids"].shape[1]
                    prediction_without_prompt_tokens_list = [
                        prediction_tokens[prompt_width:]
                        for prediction_tokens in predictions["sequences"]
                        .cpu()
                        .tolist()
                    ]

                    predicted_texts = tokenizer.batch_decode(
                        prediction_without_prompt_tokens_list,
                        skip_special_tokens=True,
                    )

                    eval_src.extend(prompt_texts)
                    eval_pred.extend(predicted_texts)
                    eval_ref.extend(completion_texts)

                return eval_src, eval_pred, eval_ref

            if is_main_process():
                eval_preds = predict_with_generate()
                trainer.log(evaluate_preds(*eval_preds))

            return control

    return CausalLMBenchEvalCallback
|
||||||
|
|
||||||
|
|
||||||
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||||
class LogPredictionCallback(TrainerCallback):
|
class LogPredictionCallback(TrainerCallback):
|
||||||
"""Callback to log prediction values during each evaluation"""
|
"""Callback to log prediction values during each evaluation"""
|
||||||
@@ -388,7 +569,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
max_new_tokens=self.cfg.eval_table_max_new_tokens,
|
max_new_tokens=self.cfg.eval_max_new_tokens,
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
|||||||
@@ -56,7 +56,13 @@ def normalize_config(cfg):
|
|||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||||
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
|
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
|
||||||
|
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
|
||||||
|
"sacrebleu",
|
||||||
|
"comet",
|
||||||
|
"ter",
|
||||||
|
"chrf",
|
||||||
|
]
|
||||||
choose_device(cfg)
|
choose_device(cfg)
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
if cfg.ddp:
|
if cfg.ddp:
|
||||||
@@ -550,6 +556,21 @@ def validate_config(cfg):
|
|||||||
if cfg.fsdp and "bnb" in cfg.optimizer:
|
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||||
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||||
|
|
||||||
|
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
||||||
|
raise ValueError(
|
||||||
|
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.eval_causal_lm_metrics:
|
||||||
|
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
|
||||||
|
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
||||||
|
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||||
|
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||||
|
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
|
||||||
|
raise ValueError(
|
||||||
|
f"eval_causal_lm_metrics must be one of {supported_metrics}"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
Reference in New Issue
Block a user