Improve support for custom datasets in bench evals
This commit is contained in:
@@ -124,29 +124,21 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
tokenizer("G", add_special_tokens=False).input_ids[0],
|
tokenizer("G", add_special_tokens=False).input_ids[0],
|
||||||
]
|
]
|
||||||
bench_split = "eval"
|
bench_split = "eval"
|
||||||
if trainer.args.bench_dataset == "sampled":
|
|
||||||
|
|
||||||
def transform_subject(example):
|
def transform_bench_subject(example):
|
||||||
# Split on ':' and trim whitespace
|
# Split on ':' and trim whitespace
|
||||||
parts = example["subject"].split(":")
|
parts = example["subject"].split(":")
|
||||||
first_part = (
|
first_part = (
|
||||||
parts[0].strip().lower().replace("-", "_")
|
parts[0].strip().lower().replace("-", "_")
|
||||||
) # Lowercase the first part
|
) # Lowercase the first part
|
||||||
second_part = (
|
second_part = (
|
||||||
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
|
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
|
||||||
) # Replace hyphens with underscores
|
) # Replace hyphens with underscores
|
||||||
|
|
||||||
# Return the transformed values
|
# Return the transformed values
|
||||||
return {"name": first_part, "subject": second_part}
|
return {"name": first_part, "subject": second_part}
|
||||||
|
|
||||||
bench_dataset = load_dataset(
|
if trainer.args.bench_dataset == "mmlu-zs":
|
||||||
"pharaouk/dharma-1",
|
|
||||||
data_files={
|
|
||||||
"eval": "dharma_1_mini.json",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
|
|
||||||
elif trainer.args.bench_dataset == "mmlu-zs":
|
|
||||||
bench_dataset = load_dataset(
|
bench_dataset = load_dataset(
|
||||||
"openaccess-ai-collective/mmlu-evals",
|
"openaccess-ai-collective/mmlu-evals",
|
||||||
data_files={
|
data_files={
|
||||||
@@ -165,6 +157,17 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
# bench_dataset = bench_dataset.remove_columns('subject')
|
# bench_dataset = bench_dataset.remove_columns('subject')
|
||||||
|
elif "/" in trainer.args.bench_dataset:
|
||||||
|
bench_ds = trainer.args.bench_dataset
|
||||||
|
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
|
||||||
|
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
|
||||||
|
bench_dataset = load_dataset(
|
||||||
|
bench_ds_name,
|
||||||
|
data_files={
|
||||||
|
"eval": bench_ds_data_file,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
|
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
|
||||||
|
|||||||
@@ -138,9 +138,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
bench_dataset: Optional[str] = field(
|
bench_dataset: Optional[str] = field(
|
||||||
default="sampled",
|
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, `sampled`"
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
do_bench_eval: Optional[bool] = field(
|
do_bench_eval: Optional[bool] = field(
|
||||||
|
|||||||
Reference in New Issue
Block a user