From 9aed60fa545cab3b803524a038afde7f041bf51b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 19 Aug 2023 18:26:19 -0400 Subject: [PATCH] add mmlu callback --- requirements.txt | 1 + src/axolotl/utils/callbacks.py | 98 ++++++++++++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 28 ++++++++++ 3 files changed, 127 insertions(+) diff --git a/requirements.txt b/requirements.txt index 156d99b48..8dfe11e1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git bitsandbytes>=0.41.1 accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b addict +evaluate fire PyYAML>=6.0 datasets diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index ddc179f39..6ffd207ca 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -1,9 +1,17 @@ """Callbacks for Trainer class""" +from __future__ import annotations + import logging import os +from typing import TYPE_CHECKING +import evaluate +import numpy as np +import torch +from datasets import load_dataset from optimum.bettertransformer import BetterTransformer +from tqdm import tqdm from transformers import ( TrainerCallback, TrainerControl, @@ -14,7 +22,11 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from axolotl.utils.bench import log_gpu_memory_usage +if TYPE_CHECKING: + from axolotl.utils.trainer import AxolotlTrainingArguments + LOG = logging.getLogger("axolotl.callbacks") +IGNORE_INDEX = -100 class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods @@ -96,3 +108,89 @@ class GPUStatsCallback( log_gpu_memory_usage(LOG, "while training", self.cfg.device) self.logged = True return control + + +def mmlu_eval_callback_factory(trainer, tokenizer): + accuracy = evaluate.load("accuracy") + abcd_idx = [ + tokenizer("A", add_special_tokens=False).input_ids[0], + tokenizer("B", 
add_special_tokens=False).input_ids[0], + tokenizer("C", add_special_tokens=False).input_ids[0], + tokenizer("D", add_special_tokens=False).input_ids[0], + ] + mmlu_split = "eval" + if trainer.args.mmlu_dataset == "mmlu-zs": + mmlu_dataset = load_dataset( + "json", + data_files={ + "eval": "data/mmlu/zero_shot_mmlu_val.json", + "test": "data/mmlu/zero_shot_mmlu_test.json", + }, + ) + mmlu_dataset = mmlu_dataset.remove_columns("subject") + # MMLU Five-shot (Eval/Test only) + elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]: + mmlu_dataset = load_dataset( + "json", + data_files={ + "eval": "data/mmlu/five_shot_mmlu_val.json", + "test": "data/mmlu/five_shot_mmlu_test.json", + }, + ) + # mmlu_dataset = mmlu_dataset.remove_columns('subject') + else: + raise ValueError("unhandled value for mmlu_dataset training args") + mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split] + if trainer.args.max_mmlu_samples is not None: + mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples)) + + class MMLUEvalCallback(TrainerCallback): + """ + TrainerCallback that runs the MMLU evals + """ + + def on_evaluate( + self, + args: AxolotlTrainingArguments, + **kwargs, # pylint: disable=unused-argument + ): + data_loader = trainer.get_eval_dataloader(mmlu_dataset) + source_max_len = trainer.data_collator.source_max_len + trainer.data_collator.source_max_len = args.mmlu_source_max_len + trainer.model.eval() + preds, refs = [], [] + loss_mmlu = 0 + for batch in tqdm(data_loader, total=len(data_loader)): + (loss, logits, labels) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + # There are two tokens, the output, and eos token. 
+ for i, logit in enumerate(logits): + label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0] + logit_abcd = logit[label_non_zero_id - 1][abcd_idx] + preds.append(torch.argmax(logit_abcd).item()) + labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] + refs += [abcd_idx.index(label) for label in labels.tolist()] + loss_mmlu += loss.item() + # Extract results by subject. + results = {"mmlu_loss": loss_mmlu / len(data_loader)} + subject = mmlu_dataset["subject"] + subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)} + for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name + subjects[s]["preds"].append(p) + subjects[s]["refs"].append(r) + subject_scores = [] + for subject in subjects: + subject_score = accuracy.compute( + references=subjects[subject]["refs"], + predictions=subjects[subject]["preds"], + )["accuracy"] + results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score + subject_scores.append(subject_score) + results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores) + trainer.log(results) + trainer.data_collator.source_max_len = source_max_len + + return MMLUEvalCallback diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 24be1b8c2..0229e8081 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -23,6 +23,7 @@ from axolotl.utils.callbacks import ( GPUStatsCallback, SaveBetterTransformerModelCallback, SavePeftModelCallback, + mmlu_eval_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader @@ -127,6 +128,27 @@ class AxolotlTrainingArguments(TrainingArguments): default=None, metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, ) + mmlu_split: Optional[str] = field( + default="eval", metadata={"help": "The MMLU split to run on"} + ) + mmlu_dataset: Optional[str] = field( + default="mmlu-fs", + metadata={ + "help": "MMLU dataset to use: 
options are `mmlu-zs` for zero-shot or `mmlu-fs` for few-shot." }, ) + do_mmlu_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the MMLU evaluation."} + ) + max_mmlu_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_mmlu_samples` of the MMLU dataset." + }, + ) + mmlu_source_max_len: int = field( + default=2048, metadata={"help": "Maximum source sequence length for mmlu."} + ) class AxolotlTrainer(Trainer): @@ -517,6 +539,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ "steps" if cfg.save_steps else "epoch" ) + if cfg.do_mmlu_eval: + training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval + training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg max_steps=total_num_steps if cfg.max_steps else -1, max_seq_length=cfg.sequence_len, @@ -631,4 +656,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) + if cfg.do_mmlu_eval: + trainer.add_callback(mmlu_eval_callback_factory(trainer, tokenizer)) + return trainer