add mmlu callback
This commit is contained in:
@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
|
|||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||||
addict
|
addict
|
||||||
|
evaluate
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
datasets
|
datasets
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
"""Callbacks for Trainer class"""
|
"""Callbacks for Trainer class"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import evaluate
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainerControl,
|
TrainerControl,
|
||||||
@@ -14,7 +22,11 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
|||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||||
@@ -96,3 +108,89 @@ class GPUStatsCallback(
|
|||||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||||
self.logged = True
|
self.logged = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||||
|
accuracy = evaluate.load("accuracy")
|
||||||
|
abcd_idx = [
|
||||||
|
tokenizer("A", add_special_tokens=False).input_ids[0],
|
||||||
|
tokenizer("B", add_special_tokens=False).input_ids[0],
|
||||||
|
tokenizer("C", add_special_tokens=False).input_ids[0],
|
||||||
|
tokenizer("D", add_special_tokens=False).input_ids[0],
|
||||||
|
]
|
||||||
|
mmlu_split = "eval"
|
||||||
|
if trainer.args.mmlu_dataset == "mmlu-zs":
|
||||||
|
mmlu_dataset = load_dataset(
|
||||||
|
"json",
|
||||||
|
data_files={
|
||||||
|
"eval": "data/mmlu/zero_shot_mmlu_val.json",
|
||||||
|
"test": "data/mmlu/zero_shot_mmlu_test.json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
||||||
|
# MMLU Five-shot (Eval/Test only)
|
||||||
|
elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
|
||||||
|
mmlu_dataset = load_dataset(
|
||||||
|
"json",
|
||||||
|
data_files={
|
||||||
|
"eval": "data/mmlu/five_shot_mmlu_val.json",
|
||||||
|
"test": "data/mmlu/five_shot_mmlu_test.json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# mmlu_dataset = mmlu_dataset.remove_columns('subject')
|
||||||
|
else:
|
||||||
|
raise ValueError("unhandled value for mmlu_dataset training args")
|
||||||
|
mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
|
||||||
|
if trainer.args.max_mmlu_samples is not None:
|
||||||
|
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
||||||
|
|
||||||
|
class MMLUEvalCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
TrainerCallback that runs the MMLU evals
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_evaluate(
|
||||||
|
self,
|
||||||
|
args: AxolotlTrainingArguments,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||||
|
source_max_len = trainer.data_collator.source_max_len
|
||||||
|
trainer.data_collator.source_max_len = args.mmlu_source_max_len
|
||||||
|
trainer.model.eval()
|
||||||
|
preds, refs = [], []
|
||||||
|
loss_mmlu = 0
|
||||||
|
for batch in tqdm(data_loader, total=len(data_loader)):
|
||||||
|
(loss, logits, labels) = trainer.prediction_step(
|
||||||
|
trainer.model,
|
||||||
|
batch,
|
||||||
|
prediction_loss_only=False,
|
||||||
|
)
|
||||||
|
# There are two tokens, the output, and eos token.
|
||||||
|
for i, logit in enumerate(logits):
|
||||||
|
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
|
||||||
|
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||||
|
preds.append(torch.argmax(logit_abcd).item())
|
||||||
|
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||||
|
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
||||||
|
loss_mmlu += loss.item()
|
||||||
|
# Extract results by subject.
|
||||||
|
results = {"mmlu_loss": loss_mmlu / len(data_loader)}
|
||||||
|
subject = mmlu_dataset["subject"]
|
||||||
|
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
||||||
|
for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name
|
||||||
|
subjects[s]["preds"].append(p)
|
||||||
|
subjects[s]["refs"].append(r)
|
||||||
|
subject_scores = []
|
||||||
|
for subject in subjects:
|
||||||
|
subject_score = accuracy.compute(
|
||||||
|
references=subjects[subject]["refs"],
|
||||||
|
predictions=subjects[subject]["preds"],
|
||||||
|
)["accuracy"]
|
||||||
|
results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score
|
||||||
|
subject_scores.append(subject_score)
|
||||||
|
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||||
|
trainer.log(results)
|
||||||
|
trainer.data_collator.source_max_len = source_max_len
|
||||||
|
|
||||||
|
return MMLUEvalCallback
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from axolotl.utils.callbacks import (
|
|||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
|
mmlu_eval_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
@@ -127,6 +128,27 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
)
|
)
|
||||||
|
mmlu_split: Optional[str] = field(
|
||||||
|
default="eval", metadata={"help": "The MMLU split to run on"}
|
||||||
|
)
|
||||||
|
mmlu_dataset: Optional[str] = field(
|
||||||
|
default="mmlu-fs",
|
||||||
|
metadata={
|
||||||
|
"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_mmlu_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the MMLU evaluation."}
|
||||||
|
)
|
||||||
|
max_mmlu_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mmlu_source_max_len: int = field(
|
||||||
|
default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -517,6 +539,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"steps" if cfg.save_steps else "epoch"
|
"steps" if cfg.save_steps else "epoch"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.do_mmlu_eval:
|
||||||
|
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
@@ -631,4 +656,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.do_mmlu_eval:
|
||||||
|
trainer.add_callback(mmlu_eval_callback_factory(trainer, tokenizer))
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|||||||
Reference in New Issue
Block a user