diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 2e9280c03..6fa809807 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict import evaluate import numpy as np +import pandas as pd import torch from datasets import load_dataset from optimum.bettertransformer import BetterTransformer @@ -120,6 +121,14 @@ def mmlu_eval_callback_factory(trainer, tokenizer): tokenizer("D", add_special_tokens=False).input_ids[0], ] mmlu_split = "eval" + if trainer.args.mmlu_dataset == "sampled": + mmlu_dataset = load_dataset( + "pharaouk/dharma-1", + data_files={ + "eval": "pharaouk/dharma-1", + }, + ) + # mmlu_dataset = mmlu_dataset.remove_columns("subject") if trainer.args.mmlu_dataset == "mmlu-zs": mmlu_dataset = load_dataset( "openaccess-ai-collective/mmlu-evals", @@ -175,7 +184,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer): with zero_first(is_main_process()): mmlu_dataset = mmlu_dataset.map(tokenize_evals) - class MMLUEvalCallback(TrainerCallback): + class BenchEvalCallback(TrainerCallback): """ TrainerCallback that runs the MMLU evals """ @@ -210,7 +219,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer): refs += [abcd_idx.index(label) for label in labels.tolist()] loss_mmlu += loss.item() # Extract results by subject. - results = {"mmlu_loss": loss_mmlu / len(data_loader)} + results = {"bench_loss": loss_mmlu / len(data_loader)} subject = mmlu_dataset["subject"] subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)} for s, p, r in zip( # pylint: disable=invalid-name @@ -224,10 +233,13 @@ def mmlu_eval_callback_factory(trainer, tokenizer): references=subjects[subject]["refs"], predictions=subjects[subject]["preds"], )["accuracy"] - results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score - subject_scores.append(subject_score) - results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores) + if not pd.isna(subject_score): + results[ + f"bench_{mmlu_split}_accuracy_{subject}" + ] = subject_score + subject_scores.append(subject_score) + results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores) trainer.log(results) trainer.data_collator.max_length = source_max_len - return MMLUEvalCallback + return BenchEvalCallback diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 37913915e..2b6895a61 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -541,6 +541,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ if cfg.do_mmlu_eval: training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval + if cfg.mmlu_dataset: + training_arguments_kwargs["mmlu_dataset"] = cfg.mmlu_dataset training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg max_steps=total_num_steps if cfg.max_steps else -1,