include metrics in callback
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
@@ -154,6 +154,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
args: AxolotlTrainingArguments,
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl, # pylint: disable=unused-argument
|
||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||
|
||||
Reference in New Issue
Block a user