rename mmlu to bench
This commit is contained in:
@@ -112,7 +112,7 @@ class GPUStatsCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def mmlu_eval_callback_factory(trainer, tokenizer):
|
def bench_eval_callback_factory(trainer, tokenizer):
|
||||||
accuracy = evaluate.load("accuracy")
|
accuracy = evaluate.load("accuracy")
|
||||||
abcd_idx = [
|
abcd_idx = [
|
||||||
tokenizer("A", add_special_tokens=False).input_ids[0],
|
tokenizer("A", add_special_tokens=False).input_ids[0],
|
||||||
@@ -123,41 +123,41 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
tokenizer("F", add_special_tokens=False).input_ids[0],
|
tokenizer("F", add_special_tokens=False).input_ids[0],
|
||||||
tokenizer("G", add_special_tokens=False).input_ids[0],
|
tokenizer("G", add_special_tokens=False).input_ids[0],
|
||||||
]
|
]
|
||||||
mmlu_split = "eval"
|
bench_split = "eval"
|
||||||
if trainer.args.mmlu_dataset == "sampled":
|
if trainer.args.bench_dataset == "sampled":
|
||||||
mmlu_dataset = load_dataset(
|
bench_dataset = load_dataset(
|
||||||
"pharaouk/dharma-1",
|
"pharaouk/dharma-1",
|
||||||
data_files={
|
data_files={
|
||||||
"eval": "dharma_eval.json",
|
"eval": "dharma_eval.json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
# bench_dataset = bench_dataset.remove_columns("subject")
|
||||||
elif trainer.args.mmlu_dataset == "mmlu-zs":
|
elif trainer.args.bench_dataset == "mmlu-zs":
|
||||||
mmlu_dataset = load_dataset(
|
bench_dataset = load_dataset(
|
||||||
"openaccess-ai-collective/mmlu-evals",
|
"openaccess-ai-collective/mmlu-evals",
|
||||||
data_files={
|
data_files={
|
||||||
"eval": "zero_shot_mmlu_val.json",
|
"eval": "zero_shot_mmlu_val.json",
|
||||||
"test": "zero_shot_mmlu_test.json",
|
"test": "zero_shot_mmlu_test.json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
# bench_dataset = bench_dataset.remove_columns("subject")
|
||||||
# MMLU Five-shot (Eval/Test only)
|
# MMLU Five-shot (Eval/Test only)
|
||||||
elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
|
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
|
||||||
mmlu_dataset = load_dataset(
|
bench_dataset = load_dataset(
|
||||||
"openaccess-ai-collective/mmlu-evals",
|
"openaccess-ai-collective/mmlu-evals",
|
||||||
data_files={
|
data_files={
|
||||||
"eval": "five_shot_mmlu_val.json",
|
"eval": "five_shot_mmlu_val.json",
|
||||||
"test": "five_shot_mmlu_test.json",
|
"test": "five_shot_mmlu_test.json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# mmlu_dataset = mmlu_dataset.remove_columns('subject')
|
# bench_dataset = bench_dataset.remove_columns('subject')
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"unhandled value `{trainer.args.mmlu_dataset}` for mmlu_dataset training args"
|
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
|
||||||
)
|
)
|
||||||
mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
|
bench_dataset = bench_dataset[trainer.args.bench_split]
|
||||||
if trainer.args.max_mmlu_samples is not None:
|
if trainer.args.max_bench_samples is not None:
|
||||||
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
|
||||||
|
|
||||||
def tokenize_evals(example):
|
def tokenize_evals(example):
|
||||||
source = f"{tokenizer.bos_token}{example['input']}"
|
source = f"{tokenizer.bos_token}{example['input']}"
|
||||||
@@ -187,8 +187,8 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
bench_dataset = bench_dataset.map(tokenize_evals)
|
||||||
mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||||
|
|
||||||
class BenchEvalCallback(TrainerCallback):
|
class BenchEvalCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
@@ -204,12 +204,12 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
data_loader = trainer.get_eval_dataloader(bench_dataset)
|
||||||
source_max_len = trainer.data_collator.max_length
|
source_max_len = trainer.data_collator.max_length
|
||||||
trainer.data_collator.max_length = args.mmlu_source_max_len
|
trainer.data_collator.max_length = args.bench_source_max_len
|
||||||
trainer.model.eval()
|
trainer.model.eval()
|
||||||
preds, refs = [], []
|
preds, refs = [], []
|
||||||
loss_mmlu = 0
|
loss_bench = 0
|
||||||
for batch in tqdm(data_loader, total=len(data_loader)):
|
for batch in tqdm(data_loader, total=len(data_loader)):
|
||||||
(loss, logits, labels) = trainer.prediction_step(
|
(loss, logits, labels) = trainer.prediction_step(
|
||||||
trainer.model,
|
trainer.model,
|
||||||
@@ -226,10 +226,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
abcd_idx.index(label) if labels in abcd_idx else -1
|
abcd_idx.index(label) if labels in abcd_idx else -1
|
||||||
for label in labels.tolist()
|
for label in labels.tolist()
|
||||||
]
|
]
|
||||||
loss_mmlu += loss.item()
|
loss_bench += loss.item()
|
||||||
# Extract results by subject.
|
# Extract results by subject.
|
||||||
results = {"bench_loss": loss_mmlu / len(data_loader)}
|
results = {"bench_loss": loss_bench / len(data_loader)}
|
||||||
subject = mmlu_dataset["subject"]
|
subject = bench_dataset["subject"]
|
||||||
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
||||||
for s, p, r in zip( # pylint: disable=invalid-name
|
for s, p, r in zip( # pylint: disable=invalid-name
|
||||||
subject, preds, refs
|
subject, preds, refs
|
||||||
@@ -244,10 +244,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
)["accuracy"]
|
)["accuracy"]
|
||||||
if not pd.isna(subject_score):
|
if not pd.isna(subject_score):
|
||||||
results[
|
results[
|
||||||
f"bench_{mmlu_split}_accuracy_{subject}"
|
f"bench_{bench_split}_accuracy_{subject}"
|
||||||
] = subject_score
|
] = subject_score
|
||||||
subject_scores.append(subject_score)
|
subject_scores.append(subject_score)
|
||||||
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
|
||||||
trainer.log(results)
|
trainer.log(results)
|
||||||
trainer.data_collator.max_length = source_max_len
|
trainer.data_collator.max_length = source_max_len
|
||||||
barrier()
|
barrier()
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from axolotl.utils.callbacks import (
|
|||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
mmlu_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
@@ -128,25 +128,25 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
)
|
)
|
||||||
mmlu_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The MMLU split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
mmlu_dataset: Optional[str] = field(
|
bench_dataset: Optional[str] = field(
|
||||||
default="mmlu-zs",
|
default="mmlu-zs",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
do_mmlu_eval: Optional[bool] = field(
|
do_bench_eval: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "Whether to run the MMLU evaluation."}
|
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||||
)
|
)
|
||||||
max_mmlu_samples: Optional[int] = field(
|
max_bench_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."
|
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mmlu_source_max_len: int = field(
|
bench_source_max_len: int = field(
|
||||||
default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
|
default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -539,10 +539,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"steps" if cfg.save_steps else "epoch"
|
"steps" if cfg.save_steps else "epoch"
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.do_mmlu_eval:
|
if cfg.do_bench_eval:
|
||||||
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_mmlu_eval
|
training_arguments_kwargs["do_mmlu_eval"] = cfg.do_bench_eval
|
||||||
if cfg.mmlu_dataset:
|
if cfg.bench_dataset:
|
||||||
training_arguments_kwargs["mmlu_dataset"] = cfg.mmlu_dataset
|
training_arguments_kwargs["mmlu_dataset"] = cfg.bench_dataset
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
@@ -658,7 +658,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.do_mmlu_eval:
|
if cfg.do_bench_eval:
|
||||||
trainer.add_callback(mmlu_eval_callback_factory(trainer, tokenizer))
|
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|||||||
Reference in New Issue
Block a user