include metrics in callback
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -154,6 +154,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
args: AxolotlTrainingArguments,
|
args: AxolotlTrainingArguments,
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
state: TrainerState, # pylint: disable=unused-argument
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl, # pylint: disable=unused-argument
|
||||||
|
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user