From d6cea180343b5fe852a478274289596eee2806b2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 28 Aug 2023 06:03:53 -0400 Subject: [PATCH] improve support for customized dataset for bench evals --- src/axolotl/utils/callbacks.py | 43 ++++++++++++++++++---------------- src/axolotl/utils/trainer.py | 4 ++-- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 7901f3f62..1896df2de 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -124,29 +124,21 @@ def bench_eval_callback_factory(trainer, tokenizer): tokenizer("G", add_special_tokens=False).input_ids[0], ] bench_split = "eval" - if trainer.args.bench_dataset == "sampled": - def transform_subject(example): - # Split on ':' and trim whitespace - parts = example["subject"].split(":") - first_part = ( - parts[0].strip().lower().replace("-", "_") - ) # Lowercase the first part - second_part = ( - parts[1].strip().replace("-", "_") if len(parts) > 1 else "all" - ) # Replace hyphens with underscores + def transform_bench_subject(example): + # Split on ':' and trim whitespace + parts = example["subject"].split(":") + first_part = ( + parts[0].strip().lower().replace("-", "_") + ) # Lowercase the first part + second_part = ( + parts[1].strip().replace("-", "_") if len(parts) > 1 else "all" + ) # Replace hyphens with underscores - # Return the transformed values - return {"name": first_part, "subject": second_part} + # Return the transformed values + return {"name": first_part, "subject": second_part} - bench_dataset = load_dataset( - "pharaouk/dharma-1", - data_files={ - "eval": "dharma_1_mini.json", - }, - ) - bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject) - elif trainer.args.bench_dataset == "mmlu-zs": + if trainer.args.bench_dataset == "mmlu-zs": bench_dataset = load_dataset( "openaccess-ai-collective/mmlu-evals", data_files={ @@ -165,6 +157,17 @@ def bench_eval_callback_factory(trainer, tokenizer): }, ) # bench_dataset = bench_dataset.remove_columns('subject') + elif "/" in trainer.args.bench_dataset: + bench_ds = trainer.args.bench_dataset + bench_ds_name = "/".join(bench_ds.split("/", 2)[:2]) + bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:]) + bench_dataset = load_dataset( + bench_ds_name, + data_files={ + "eval": bench_ds_data_file, + }, + ) + bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject) else: raise ValueError( f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args" diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8ba12cbbf..bac59068a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -138,9 +138,9 @@ class AxolotlTrainingArguments(TrainingArguments): default="eval", metadata={"help": "The benchmark split to run on"} ) bench_dataset: Optional[str] = field( - default="sampled", + default="pharaouk/dharma-1/dharma_1_mini.json", metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`" + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" }, ) do_bench_eval: Optional[bool] = field(