diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 92333f4ca..ee5acfd55 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -275,7 +275,7 @@ def bench_eval_callback_factory(trainer, tokenizer): else: dist.gather_object(local_bench_names, gathered_bench_names, dst=0) bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks) - results = {"bench_loss": bench_loss} + results = {f"{bench_split}_bench_loss": bench_loss} # Combine results from all GPUs combined_bench_names: Dict[str, Dict[str, List]] = {} @@ -287,6 +287,8 @@ def bench_eval_callback_factory(trainer, tokenizer): combined_bench_names[name]["preds"].extend(data["preds"]) bench_scores = [] + bench_refs = [] + bench_preds = [] for ( bench_name ) in combined_bench_names: # pylint: disable=consider-using-dict-items @@ -294,15 +296,20 @@ def bench_eval_callback_factory(trainer, tokenizer): references=combined_bench_names[bench_name]["refs"], predictions=combined_bench_names[bench_name]["preds"], )["accuracy"] + bench_refs.extend(combined_bench_names[bench_name]["refs"]) + bench_preds.extend(combined_bench_names[bench_name]["preds"]) if not pd.isna(bench_score): results[ - f"bench_{bench_split}_accuracy_{bench_name}" + f"{bench_split}_bench_accuracy_{bench_name}" ] = bench_score bench_scores.append(bench_score) else: - results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0 + results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0 bench_scores.append(0.0) - results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores) + results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores) + results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute( + references=bench_refs, predictions=bench_preds + )["accuracy"] trainer.log(results) return BenchEvalCallback