rename mmlu to bench

Wing Lian
2023-08-21 04:38:51 -04:00
parent ef062d8fcb
commit 918e040601
2 changed files with 41 additions and 41 deletions

src/axolotl/utils/callbacks.py

@@ -112,7 +112,7 @@ class GPUStatsCallback(
         return control


-def mmlu_eval_callback_factory(trainer, tokenizer):
+def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
         tokenizer("A", add_special_tokens=False).input_ids[0],
@@ -123,41 +123,41 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
         tokenizer("F", add_special_tokens=False).input_ids[0],
         tokenizer("G", add_special_tokens=False).input_ids[0],
     ]
-    mmlu_split = "eval"
-    if trainer.args.mmlu_dataset == "sampled":
-        mmlu_dataset = load_dataset(
+    bench_split = "eval"
+    if trainer.args.bench_dataset == "sampled":
+        bench_dataset = load_dataset(
             "pharaouk/dharma-1",
             data_files={
                 "eval": "dharma_eval.json",
             },
         )
-        # mmlu_dataset = mmlu_dataset.remove_columns("subject")
-    elif trainer.args.mmlu_dataset == "mmlu-zs":
-        mmlu_dataset = load_dataset(
+        # bench_dataset = bench_dataset.remove_columns("subject")
+    elif trainer.args.bench_dataset == "mmlu-zs":
+        bench_dataset = load_dataset(
             "openaccess-ai-collective/mmlu-evals",
             data_files={
                 "eval": "zero_shot_mmlu_val.json",
                 "test": "zero_shot_mmlu_test.json",
             },
         )
-        # mmlu_dataset = mmlu_dataset.remove_columns("subject")
+        # bench_dataset = bench_dataset.remove_columns("subject")
     # MMLU Five-shot (Eval/Test only)
-    elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
-        mmlu_dataset = load_dataset(
+    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
+        bench_dataset = load_dataset(
             "openaccess-ai-collective/mmlu-evals",
             data_files={
                 "eval": "five_shot_mmlu_val.json",
                 "test": "five_shot_mmlu_test.json",
             },
         )
-        # mmlu_dataset = mmlu_dataset.remove_columns('subject')
+        # bench_dataset = bench_dataset.remove_columns('subject')
     else:
         raise ValueError(
-            f"unhandled value `{trainer.args.mmlu_dataset}` for mmlu_dataset training args"
+            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
         )
-    mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
-    if trainer.args.max_mmlu_samples is not None:
-        mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
+    bench_dataset = bench_dataset[trainer.args.bench_split]
+    if trainer.args.max_bench_samples is not None:
+        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))

     def tokenize_evals(example):
         source = f"{tokenizer.bos_token}{example['input']}"
@@ -187,8 +187,8 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
         }

     with zero_first(is_main_process()):
-        mmlu_dataset = mmlu_dataset.map(tokenize_evals)
-        mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
+        bench_dataset = bench_dataset.map(tokenize_evals)
+        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)

     class BenchEvalCallback(TrainerCallback):
         """
@@ -204,12 +204,12 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
             **kwargs, # pylint: disable=unused-argument
         ):
             if is_main_process():
-                data_loader = trainer.get_eval_dataloader(mmlu_dataset)
+                data_loader = trainer.get_eval_dataloader(bench_dataset)
                 source_max_len = trainer.data_collator.max_length
-                trainer.data_collator.max_length = args.mmlu_source_max_len
+                trainer.data_collator.max_length = args.bench_source_max_len
                 trainer.model.eval()
                 preds, refs = [], []
-                loss_mmlu = 0
+                loss_bench = 0
                 for batch in tqdm(data_loader, total=len(data_loader)):
                     (loss, logits, labels) = trainer.prediction_step(
                         trainer.model,
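
The body of this loop (elided between this hunk and the next) turns each batch's logits into letter predictions. A hedged sketch of the qlora-style scoring this kind of callback typically uses, assuming `logits` is (batch, seq, vocab) and `-100` marks masked label positions:

import torch

def letter_predictions(logits, labels, answer_ids):
    preds = []
    for i, row_logits in enumerate(logits):
        # first unmasked label position: where the answer letter sits
        answer_pos = (labels[i] != -100).nonzero()[0][0]
        # logits that predict that position, restricted to the letter ids
        letter_logits = row_logits[answer_pos - 1][answer_ids]
        preds.append(torch.argmax(letter_logits).item())
    return preds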
@@ -226,10 +226,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
                         abcd_idx.index(label) if label in abcd_idx else -1
                         for label in labels.tolist()
                     ]
-                    loss_mmlu += loss.item()
+                    loss_bench += loss.item()
                 # Extract results by subject.
-                results = {"bench_loss": loss_mmlu / len(data_loader)}
-                subject = mmlu_dataset["subject"]
+                results = {"bench_loss": loss_bench / len(data_loader)}
+                subject = bench_dataset["subject"]
                 subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
                 for s, p, r in zip( # pylint: disable=invalid-name
                     subject, preds, refs
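
Each prediction/reference pair is bucketed by subject and scored with the accuracy metric from the `evaluate` library, which the next hunk logs per split. For reference, a minimal example of that metric call (values made up):

import evaluate

accuracy = evaluate.load("accuracy")
score = accuracy.compute(references=[0, 1, 2], predictions=[0, 1, 3])["accuracy"]
print(score)  # 2 of 3 match -> 0.666...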
@@ -244,10 +244,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
                     )["accuracy"]
                     if not pd.isna(subject_score):
                         results[
-                            f"bench_{mmlu_split}_accuracy_{subject}"
+                            f"bench_{bench_split}_accuracy_{subject}"
                         ] = subject_score
                         subject_scores.append(subject_score)
-                results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
+                results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
                 trainer.log(results)
                 trainer.data_collator.max_length = source_max_len
             barrier()
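
Note the pattern that frames the whole callback: only the main process runs the benchmark, and `barrier()` holds the other ranks until it finishes. A rough sketch of the same idea in plain torch.distributed (assumes an initialized process group; `run_benchmark` is a hypothetical stand-in for the eval loop above):

import torch.distributed as dist

if dist.get_rank() == 0:
    run_benchmark()  # hypothetical stand-in for the eval loop
dist.barrier()  # every rank waits here until rank 0 is done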

src/axolotl/utils/trainer.py

@@ -23,7 +23,7 @@ from axolotl.utils.callbacks import (
     GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
-    mmlu_eval_callback_factory,
+    bench_eval_callback_factory,
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -128,25 +128,25 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
-    mmlu_split: Optional[str] = field(
-        default="eval", metadata={"help": "The MMLU split to run on"}
+    bench_split: Optional[str] = field(
+        default="eval", metadata={"help": "The benchmark split to run on"}
     )
-    mmlu_dataset: Optional[str] = field(
+    bench_dataset: Optional[str] = field(
         default="mmlu-zs",
         metadata={
-            "help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."
+            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
         },
     )
-    do_mmlu_eval: Optional[bool] = field(
-        default=False, metadata={"help": "Whether to run the MMLU evaluation."}
+    do_bench_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
     )
-    max_mmlu_samples: Optional[int] = field(
+    max_bench_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."
+            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
         },
     )
-    mmlu_source_max_len: int = field(
+    bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
     )
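
Together these hunks rename five training-argument fields. A sketch of how they might be populated after the rename (non-default values are illustrative only):

training_arguments_kwargs = {
    "do_bench_eval": True,
    "bench_dataset": "mmlu-zs",  # or "mmlu-fs" / "sampled" per the help text above
    "bench_split": "eval",
    "max_bench_samples": 256,  # illustrative cap
    "bench_source_max_len": 2048,
}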
@@ -539,10 +539,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             "steps" if cfg.save_steps else "epoch"
         )
-    if cfg.do_mmlu_eval:
-        training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
-    if cfg.mmlu_dataset:
-        training_arguments_kwargs["mmlu_dataset"] = cfg.mmlu_dataset
+    if cfg.do_bench_eval:
+        training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
+    if cfg.bench_dataset:
+        training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset

     training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
@@ -658,7 +658,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         **trainer_kwargs,
     )

-    if cfg.do_mmlu_eval:
-        trainer.add_callback(mmlu_eval_callback_factory(trainer, tokenizer))
+    if cfg.do_bench_eval:
+        trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

     return trainer
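
End to end, the rename means downstream code switches from the mmlu_* spellings to bench_*. A hedged usage sketch (trainer and tokenizer construction omitted; assumes a trainer whose args carry the renamed bench_* fields):

callback = bench_eval_callback_factory(trainer, tokenizer)
trainer.add_callback(callback)
trainer.evaluate()  # BenchEvalCallback logs bench_loss and bench_*_accuracy metrics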