rename mmlu to bench
This commit is contained in:
@@ -112,7 +112,7 @@ class GPUStatsCallback(
|
||||
return control
|
||||
|
||||
|
||||
def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
tokenizer("A", add_special_tokens=False).input_ids[0],
|
||||
@@ -123,41 +123,41 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
tokenizer("F", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("G", add_special_tokens=False).input_ids[0],
|
||||
]
|
||||
mmlu_split = "eval"
|
||||
if trainer.args.mmlu_dataset == "sampled":
|
||||
mmlu_dataset = load_dataset(
|
||||
bench_split = "eval"
|
||||
if trainer.args.bench_dataset == "sampled":
|
||||
bench_dataset = load_dataset(
|
||||
"pharaouk/dharma-1",
|
||||
data_files={
|
||||
"eval": "dharma_eval.json",
|
||||
},
|
||||
)
|
||||
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
||||
elif trainer.args.mmlu_dataset == "mmlu-zs":
|
||||
mmlu_dataset = load_dataset(
|
||||
# bench_dataset = bench_dataset.remove_columns("subject")
|
||||
elif trainer.args.bench_dataset == "mmlu-zs":
|
||||
bench_dataset = load_dataset(
|
||||
"openaccess-ai-collective/mmlu-evals",
|
||||
data_files={
|
||||
"eval": "zero_shot_mmlu_val.json",
|
||||
"test": "zero_shot_mmlu_test.json",
|
||||
},
|
||||
)
|
||||
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
||||
# bench_dataset = bench_dataset.remove_columns("subject")
|
||||
# MMLU Five-shot (Eval/Test only)
|
||||
elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
|
||||
mmlu_dataset = load_dataset(
|
||||
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
|
||||
bench_dataset = load_dataset(
|
||||
"openaccess-ai-collective/mmlu-evals",
|
||||
data_files={
|
||||
"eval": "five_shot_mmlu_val.json",
|
||||
"test": "five_shot_mmlu_test.json",
|
||||
},
|
||||
)
|
||||
# mmlu_dataset = mmlu_dataset.remove_columns('subject')
|
||||
# bench_dataset = bench_dataset.remove_columns('subject')
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unhandled value `{trainer.args.mmlu_dataset}` for mmlu_dataset training args"
|
||||
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
|
||||
)
|
||||
mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
|
||||
if trainer.args.max_mmlu_samples is not None:
|
||||
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
||||
bench_dataset = bench_dataset[trainer.args.bench_split]
|
||||
if trainer.args.max_bench_samples is not None:
|
||||
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
|
||||
|
||||
def tokenize_evals(example):
|
||||
source = f"{tokenizer.bos_token}{example['input']}"
|
||||
@@ -187,8 +187,8 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
}
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
||||
mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||
bench_dataset = bench_dataset.map(tokenize_evals)
|
||||
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||
|
||||
class BenchEvalCallback(TrainerCallback):
|
||||
"""
|
||||
@@ -204,12 +204,12 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||
data_loader = trainer.get_eval_dataloader(bench_dataset)
|
||||
source_max_len = trainer.data_collator.max_length
|
||||
trainer.data_collator.max_length = args.mmlu_source_max_len
|
||||
trainer.data_collator.max_length = args.bench_source_max_len
|
||||
trainer.model.eval()
|
||||
preds, refs = [], []
|
||||
loss_mmlu = 0
|
||||
loss_bench = 0
|
||||
for batch in tqdm(data_loader, total=len(data_loader)):
|
||||
(loss, logits, labels) = trainer.prediction_step(
|
||||
trainer.model,
|
||||
@@ -226,10 +226,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
abcd_idx.index(label) if labels in abcd_idx else -1
|
||||
for label in labels.tolist()
|
||||
]
|
||||
loss_mmlu += loss.item()
|
||||
loss_bench += loss.item()
|
||||
# Extract results by subject.
|
||||
results = {"bench_loss": loss_mmlu / len(data_loader)}
|
||||
subject = mmlu_dataset["subject"]
|
||||
results = {"bench_loss": loss_bench / len(data_loader)}
|
||||
subject = bench_dataset["subject"]
|
||||
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
||||
for s, p, r in zip( # pylint: disable=invalid-name
|
||||
subject, preds, refs
|
||||
@@ -244,10 +244,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
)["accuracy"]
|
||||
if not pd.isna(subject_score):
|
||||
results[
|
||||
f"bench_{mmlu_split}_accuracy_{subject}"
|
||||
f"bench_{bench_split}_accuracy_{subject}"
|
||||
] = subject_score
|
||||
subject_scores.append(subject_score)
|
||||
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||
results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
|
||||
trainer.log(results)
|
||||
trainer.data_collator.max_length = source_max_len
|
||||
barrier()
|
||||
|
||||
@@ -23,7 +23,7 @@ from axolotl.utils.callbacks import (
|
||||
GPUStatsCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
mmlu_eval_callback_factory,
|
||||
bench_eval_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
@@ -128,25 +128,25 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
mmlu_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The MMLU split to run on"}
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
mmlu_dataset: Optional[str] = field(
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="mmlu-zs",
|
||||
metadata={
|
||||
"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
|
||||
},
|
||||
)
|
||||
do_mmlu_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the MMLU evaluation."}
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_mmlu_samples: Optional[int] = field(
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
mmlu_source_max_len: int = field(
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
|
||||
)
|
||||
|
||||
@@ -539,10 +539,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
"steps" if cfg.save_steps else "epoch"
|
||||
)
|
||||
|
||||
if cfg.do_mmlu_eval:
|
||||
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
|
||||
if cfg.mmlu_dataset:
|
||||
training_arguments_kwargs["mmlu_dataset"] = cfg.mmlu_dataset
|
||||
if cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_bench_eval
|
||||
if cfg.bench_dataset:
|
||||
training_arguments_kwargs["mmlu_dataset"] = cfg.bench_dataset
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||
@@ -658,7 +658,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
if cfg.do_mmlu_eval:
|
||||
trainer.add_callback(mmlu_eval_callback_factory(trainer, tokenizer))
|
||||
if cfg.do_bench_eval:
|
||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||
|
||||
return trainer
|
||||
|
||||
Reference in New Issue
Block a user