include metrics in callback

This commit is contained in:
Wing Lian
2023-08-19 21:26:32 -04:00
parent e3b07402a7
commit 6f166464d8

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Dict
import evaluate
import numpy as np
@@ -154,6 +154,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_eval_dataloader(mmlu_dataset)