Changed Bench Eval to report metrics correctly by split. Added total accuracy and renamed previously used bench_accuracy to bench_average_accuracy. (#512)

* Added "eval_" prefix

* Added total bench accuracy and renamed the previous one to bench_average_accuracy. Changed naming to use bench_split instead of always using eval_ prefix.
This commit is contained in:
Alpay Ariyak
2023-08-31 01:00:50 -04:00
committed by GitHub
parent c56b450cf5
commit 42f9642792

View File

@@ -275,7 +275,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {"bench_loss": bench_loss}
results = {f"{bench_split}_bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
@@ -287,6 +287,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
bench_refs = []
bench_preds = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
@@ -294,15 +296,20 @@ def bench_eval_callback_factory(trainer, tokenizer):
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"bench_{bench_split}_accuracy_{bench_name}"
f"{bench_split}_bench_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
references=bench_refs, predictions=bench_preds
)["accuracy"]
trainer.log(results)
return BenchEvalCallback