add mmlu callback

This commit is contained in:
Wing Lian
2023-08-19 18:26:19 -04:00
parent 98bf76e236
commit 9aed60fa54
3 changed files with 127 additions and 0 deletions

View File

@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict
evaluate
fire
PyYAML>=6.0
datasets

View File

@@ -1,9 +1,17 @@
"""Callbacks for Trainer class"""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
TrainerCallback,
TrainerControl,
@@ -14,7 +22,11 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -96,3 +108,89 @@ class GPUStatsCallback(
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control
def mmlu_eval_callback_factory(trainer, tokenizer):
    """Build a ``TrainerCallback`` class that scores the model on MMLU.

    The returned class is bound via closure to ``trainer`` and ``tokenizer``:
    it loads the MMLU split configured in ``trainer.args`` once, and on every
    evaluation event runs prediction over that split, deriving the predicted
    choice from the logits at the A/B/C/D answer-token positions.

    Args:
        trainer: the Trainer whose args, collator, model and dataloader are used.
        tokenizer: tokenizer used to resolve the answer-letter token ids.

    Returns:
        The callback *class* (not an instance); pass it to ``trainer.add_callback``.

    Raises:
        ValueError: if ``trainer.args.mmlu_dataset`` is not a recognized option.
    """
    accuracy = evaluate.load("accuracy")
    # First token id of each answer letter; the model's prediction is the
    # argmax over the logits restricted to these four ids.
    abcd_idx = [
        tokenizer("A", add_special_tokens=False).input_ids[0],
        tokenizer("B", add_special_tokens=False).input_ids[0],
        tokenizer("C", add_special_tokens=False).input_ids[0],
        tokenizer("D", add_special_tokens=False).input_ids[0],
    ]
    # FIX: previously hardcoded to "eval", which mislabeled the metric keys
    # whenever trainer.args.mmlu_split == "test" (the dataset below is
    # selected by trainer.args.mmlu_split either way).
    mmlu_split = trainer.args.mmlu_split
    # MMLU zero-shot (Eval/Test only)
    if trainer.args.mmlu_dataset == "mmlu-zs":
        mmlu_dataset = load_dataset(
            "json",
            data_files={
                "eval": "data/mmlu/zero_shot_mmlu_val.json",
                "test": "data/mmlu/zero_shot_mmlu_test.json",
            },
        )
        # NOTE(review): dropping "subject" here makes the per-subject lookup
        # in on_evaluate fail for this dataset option — confirm intended.
        mmlu_dataset = mmlu_dataset.remove_columns("subject")
    # MMLU Five-shot (Eval/Test only)
    elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
        mmlu_dataset = load_dataset(
            "json",
            data_files={
                "eval": "data/mmlu/five_shot_mmlu_val.json",
                "test": "data/mmlu/five_shot_mmlu_test.json",
            },
        )
        # mmlu_dataset = mmlu_dataset.remove_columns('subject')
    else:
        raise ValueError("unhandled value for mmlu_dataset training args")
    mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
    if trainer.args.max_mmlu_samples is not None:
        mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))

    class MMLUEvalCallback(TrainerCallback):
        """
        TrainerCallback that runs the MMLU evals
        """

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,
            state,  # pylint: disable=unused-argument
            control,  # pylint: disable=unused-argument
            **kwargs,  # pylint: disable=unused-argument
        ):
            # FIX: transformers' CallbackHandler.call_event invokes
            # on_evaluate with (args, state, control) positionally; the
            # previous (self, args, **kwargs) signature raised TypeError.
            data_loader = trainer.get_eval_dataloader(mmlu_dataset)
            # Temporarily widen the collator's source length for the MMLU
            # prompts; restored after scoring.
            source_max_len = trainer.data_collator.source_max_len
            trainer.data_collator.source_max_len = args.mmlu_source_max_len
            trainer.model.eval()
            preds, refs = [], []
            loss_mmlu = 0
            for batch in tqdm(data_loader, total=len(data_loader)):
                (loss, logits, labels) = trainer.prediction_step(
                    trainer.model,
                    batch,
                    prediction_loss_only=False,
                )
                # There are two label tokens per example: the answer letter
                # and the eos token.
                for i, logit in enumerate(logits):
                    # Position of the first non-ignored label token; the
                    # prediction for it lives in the preceding logit row.
                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
                        0
                    ][0]
                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
                    preds.append(torch.argmax(logit_abcd).item())
                # Keep only the answer token (first of the two) per example.
                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
                refs += [abcd_idx.index(label) for label in labels.tolist()]
                loss_mmlu += loss.item()
            # Extract results by subject.
            results = {"mmlu_loss": loss_mmlu / len(data_loader)}
            subject = mmlu_dataset["subject"]
            subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
            for s, p, r in zip(subject, preds, refs):  # pylint: disable=invalid-name
                subjects[s]["preds"].append(p)
                subjects[s]["refs"].append(r)
            subject_scores = []
            for subj, scored in subjects.items():
                subject_score = accuracy.compute(
                    references=scored["refs"],
                    predictions=scored["preds"],
                )["accuracy"]
                results[f"mmlu_{mmlu_split}_accuracy_{subj}"] = subject_score
                subject_scores.append(subject_score)
            results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
            trainer.log(results)
            trainer.data_collator.source_max_len = source_max_len

    return MMLUEvalCallback

View File

@@ -23,6 +23,7 @@ from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
mmlu_eval_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -127,6 +128,27 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
# MMLU-evaluation options for AxolotlTrainingArguments.
mmlu_split: Optional[str] = field(
    default="eval", metadata={"help": "The MMLU split to run on"}
)
mmlu_dataset: Optional[str] = field(
    default="mmlu-fs",
    metadata={
        "help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."
    },
)
do_mmlu_eval: Optional[bool] = field(
    default=False, metadata={"help": "Whether to run the MMLU evaluation."}
)
max_mmlu_samples: Optional[int] = field(
    default=None,
    metadata={
        # FIX: help text previously said "MMMLU" (typo).
        "help": "If set, only evaluates on `max_mmlu_samples` of the MMLU dataset."
    },
)
mmlu_source_max_len: int = field(
    default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
)
class AxolotlTrainer(Trainer):
@@ -517,6 +539,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"steps" if cfg.save_steps else "epoch"
)
if cfg.do_mmlu_eval:
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
@@ -631,4 +656,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
**trainer_kwargs,
)
if cfg.do_mmlu_eval:
trainer.add_callback(mmlu_eval_callback_factory(trainer, tokenizer))
return trainer