diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 7039089b8..2f632a294 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -125,13 +125,25 @@ def bench_eval_callback_factory(trainer, tokenizer):
     ]
     bench_split = "eval"
     if trainer.args.bench_dataset == "sampled":
+
+        def transform_subject(example):
+            # Split on ':' and trim whitespace
+            parts = example["subject"].split(":")
+            first_part = parts[0].strip().lower()  # Lowercase the first part
+            second_part = (
+                parts[1].strip().replace("-", "_")
+            )  # Replace hyphens with underscores
+
+            # Return the transformed values
+            return {"name": first_part, "subject": second_part}
+
         bench_dataset = load_dataset(
             "pharaouk/dharma-1",
             data_files={
                 "eval": "dharma_eval.json",
             },
         )
-        # bench_dataset = bench_dataset.remove_columns("subject")
+        bench_dataset["eval"] = bench_dataset["eval"].map(transform_subject)
     elif trainer.args.bench_dataset == "mmlu-zs":
         bench_dataset = load_dataset(
             "openaccess-ai-collective/mmlu-evals",
@@ -228,21 +240,21 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 loss_bench += loss.item()
             # Extract results by subject.
             results = {"bench_loss": loss_bench / len(data_loader)}
-            subject = bench_dataset["subject"]
-            subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
-            for s, p, r in zip(subject, preds, refs):  # pylint: disable=invalid-name
-                subjects[s]["preds"].append(p)
-                subjects[s]["refs"].append(r)
-            subject_scores = []
-            for subject in subjects:
-                subject_score = accuracy.compute(
-                    references=subjects[subject]["refs"],
-                    predictions=subjects[subject]["preds"],
+            bench_name = bench_dataset["name"]
+            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
+            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
+                bench_names[s]["preds"].append(p)
+                bench_names[s]["refs"].append(r)
+            bench_scores = []
+            for bench_name in bench_names:
+                bench_score = accuracy.compute(
+                    references=bench_names[bench_name]["refs"],
+                    predictions=bench_names[bench_name]["preds"],
                 )["accuracy"]
-                if not pd.isna(subject_score):
-                    results[f"bench_{bench_split}_accuracy_{subject}"] = subject_score
-                    subject_scores.append(subject_score)
-            results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
+                if not pd.isna(bench_score):
+                    results[f"bench_{bench_split}_accuracy_{bench_name}"] = bench_score
+                    bench_scores.append(bench_score)
+            results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
             trainer.log(results)
             trainer.data_collator.max_length = source_max_len
 
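
For context on what the two hunks do together: `transform_subject` splits each record's `subject` field on a colon, lowercases the left half into a new `name` column, and normalizes the right half (hyphens to underscores) back into `subject`; the eval loop then groups predictions by `name` rather than `subject` and reports a macro-averaged accuracy. Below is a minimal, self-contained sketch of that flow. The sample rows are hypothetical (the actual field format in `pharaouk/dharma-1` may differ, and records without a colon would raise an IndexError here, as in the patched code), and a plain mean stands in for `evaluate.load("accuracy")`.

import numpy as np

def transform_subject(example):
    # Same logic as the diff: "ARC: challenge-set" -> name="arc", subject="challenge_set"
    parts = example["subject"].split(":")
    return {
        "name": parts[0].strip().lower(),
        "subject": parts[1].strip().replace("-", "_"),
    }

# Hypothetical rows; the real dharma_eval.json contents may differ.
rows = [{"subject": "ARC: challenge-set"}, {"subject": "HellaSwag: validation"}]
names = [transform_subject(r)["name"] for r in rows]  # ["arc", "hellaswag"]

# Group (pred, ref) pairs by name and macro-average per-group accuracy,
# mirroring the second hunk with accuracy.compute replaced by a plain mean.
preds, refs = [0, 1], [0, 0]
groups: dict = {n: {"refs": [], "preds": []} for n in set(names)}
for n, p, r in zip(names, preds, refs):
    groups[n]["preds"].append(p)
    groups[n]["refs"].append(r)
scores = [
    np.mean(np.array(g["preds"]) == np.array(g["refs"])) for g in groups.values()
]
print(np.mean(scores))  # overall bench accuracy: 0.5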