sample benchmarks, ensure we drop long samples
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict
|
|||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -120,6 +121,14 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
tokenizer("D", add_special_tokens=False).input_ids[0],
|
tokenizer("D", add_special_tokens=False).input_ids[0],
|
||||||
]
|
]
|
||||||
mmlu_split = "eval"
|
mmlu_split = "eval"
|
||||||
|
if trainer.args.mmlu_dataset == "sampled":
|
||||||
|
mmlu_dataset = load_dataset(
|
||||||
|
"pharaouk/dharma-1",
|
||||||
|
data_files={
|
||||||
|
"eval": "pharaouk/dharma-1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
||||||
if trainer.args.mmlu_dataset == "mmlu-zs":
|
if trainer.args.mmlu_dataset == "mmlu-zs":
|
||||||
mmlu_dataset = load_dataset(
|
mmlu_dataset = load_dataset(
|
||||||
"openaccess-ai-collective/mmlu-evals",
|
"openaccess-ai-collective/mmlu-evals",
|
||||||
@@ -175,7 +184,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
||||||
|
|
||||||
class MMLUEvalCallback(TrainerCallback):
|
class BenchEvalCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
TrainerCallback that runs the MMLU evals
|
TrainerCallback that runs the MMLU evals
|
||||||
"""
|
"""
|
||||||
@@ -210,7 +219,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
||||||
loss_mmlu += loss.item()
|
loss_mmlu += loss.item()
|
||||||
# Extract results by subject.
|
# Extract results by subject.
|
||||||
results = {"mmlu_loss": loss_mmlu / len(data_loader)}
|
results = {"bench_loss": loss_mmlu / len(data_loader)}
|
||||||
subject = mmlu_dataset["subject"]
|
subject = mmlu_dataset["subject"]
|
||||||
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
||||||
for s, p, r in zip( # pylint: disable=invalid-name
|
for s, p, r in zip( # pylint: disable=invalid-name
|
||||||
@@ -224,10 +233,13 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
references=subjects[subject]["refs"],
|
references=subjects[subject]["refs"],
|
||||||
predictions=subjects[subject]["preds"],
|
predictions=subjects[subject]["preds"],
|
||||||
)["accuracy"]
|
)["accuracy"]
|
||||||
results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score
|
if not pd.isna(subject_score):
|
||||||
subject_scores.append(subject_score)
|
results[
|
||||||
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
f"bench_{mmlu_split}_accuracy_{subject}"
|
||||||
|
] = subject_score
|
||||||
|
subject_scores.append(subject_score)
|
||||||
|
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||||
trainer.log(results)
|
trainer.log(results)
|
||||||
trainer.data_collator.max_length = source_max_len
|
trainer.data_collator.max_length = source_max_len
|
||||||
|
|
||||||
return MMLUEvalCallback
|
return BenchEvalCallback
|
||||||
|
|||||||
@@ -541,6 +541,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
|
|
||||||
if cfg.do_mmlu_eval:
|
if cfg.do_mmlu_eval:
|
||||||
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
|
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
|
||||||
|
if cfg.mmlu_dataset:
|
||||||
|
training_arguments_kwargs["mmlu_dataset"] = cfg.mmlu_dataset
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
|
|||||||
Reference in New Issue
Block a user