include metrics in callback

This commit is contained in:
Wing Lian
2023-08-19 21:26:32 -04:00
parent e3b07402a7
commit 6f166464d8

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Dict
import evaluate import evaluate
import numpy as np import numpy as np
@@ -154,6 +154,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
args: AxolotlTrainingArguments, args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
data_loader = trainer.get_eval_dataloader(mmlu_dataset) data_loader = trainer.get_eval_dataloader(mmlu_dataset)