Dataset handling and aggregation across benchmarks

This commit is contained in:
Wing Lian
2023-08-21 16:56:40 -04:00
parent 2455254b92
commit 24b0e93235

View File

@@ -125,13 +125,25 @@ def bench_eval_callback_factory(trainer, tokenizer):
]
bench_split = "eval"
if trainer.args.bench_dataset == "sampled":
def transform_subject(example):
    """Split a combined ``"<benchmark>: <subject>"`` label into two fields.

    Intended as a per-row callback for ``Dataset.map``: the returned mapping
    adds a ``"name"`` column and overwrites the ``"subject"`` column.

    Args:
        example: A dataset row whose ``"subject"`` value is a string such as
            ``"MMLU: abstract-algebra"``.

    Returns:
        A dict with two keys:

        * ``"name"`` -- the text before the first ``:``, stripped and
          lowercased.
        * ``"subject"`` -- the text after the first ``:``, stripped, with
          hyphens replaced by underscores.

        When the label contains no ``:`` the whole label is used for both
        fields (each with its own normalization) instead of raising
        ``IndexError``, so one malformed row cannot abort the mapping pass.
    """
    # partition() always yields three parts, so a label without a colon is
    # handled explicitly, and any further colons stay inside the subject
    # rather than being silently discarded.
    head, separator, tail = example["subject"].partition(":")
    raw_subject = tail if separator else head
    return {
        "name": head.strip().lower(),
        "subject": raw_subject.strip().replace("-", "_"),
    }
bench_dataset = load_dataset(
"pharaouk/dharma-1",
data_files={
"eval": "dharma_eval.json",
},
)
# bench_dataset = bench_dataset.remove_columns("subject")
bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
elif trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
@@ -228,21 +240,21 @@ def bench_eval_callback_factory(trainer, tokenizer):
loss_bench += loss.item()
# Extract results by subject.
results = {"bench_loss": loss_bench / len(data_loader)}
subject = bench_dataset["subject"]
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name
subjects[s]["preds"].append(p)
subjects[s]["refs"].append(r)
subject_scores = []
for subject in subjects:
subject_score = accuracy.compute(
references=subjects[subject]["refs"],
predictions=subjects[subject]["preds"],
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
bench_scores = []
for bench_name in bench_names:
bench_score = accuracy.compute(
references=bench_names[bench_name]["refs"],
predictions=bench_names[bench_name]["preds"],
)["accuracy"]
if not pd.isna(subject_score):
results[f"bench_{bench_split}_accuracy_{subject}"] = subject_score
subject_scores.append(subject_score)
results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
if not pd.isna(bench_score):
results[f"bench_{bench_split}_accuracy_{bench_name}"] = bench_score
bench_scores.append(bench_score)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
trainer.log(results)
trainer.data_collator.max_length = source_max_len