Improve support for custom datasets in bench evals

This commit is contained in:
Wing Lian
2023-08-28 06:03:53 -04:00
parent 606846e0a5
commit d6cea18034
2 changed files with 25 additions and 22 deletions

View File

@@ -124,29 +124,21 @@ def bench_eval_callback_factory(trainer, tokenizer):
tokenizer("G", add_special_tokens=False).input_ids[0],
]
bench_split = "eval"
if trainer.args.bench_dataset == "sampled":
def transform_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
def transform_bench_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
# Return the transformed values
return {"name": first_part, "subject": second_part}
# Return the transformed values
return {"name": first_part, "subject": second_part}
bench_dataset = load_dataset(
"pharaouk/dharma-1",
data_files={
"eval": "dharma_1_mini.json",
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
elif trainer.args.bench_dataset == "mmlu-zs":
if trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
@@ -165,6 +157,17 @@ def bench_eval_callback_factory(trainer, tokenizer):
},
)
# bench_dataset = bench_dataset.remove_columns('subject')
elif "/" in trainer.args.bench_dataset:
bench_ds = trainer.args.bench_dataset
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
bench_dataset = load_dataset(
bench_ds_name,
data_files={
"eval": bench_ds_data_file,
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
else:
raise ValueError(
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"

View File

@@ -138,9 +138,9 @@ class AxolotlTrainingArguments(TrainingArguments):
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="sampled",
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(