diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 98ff9b3b9..103a38715 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -7,14 +7,15 @@ import os
 import sys
 from dataclasses import field
 from pathlib import Path
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import bitsandbytes as bnb
+import numpy as np
 import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
@@ -329,6 +330,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             num_proc=32,
         )
 
+    def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
+        logits, labels = eval_preds
+        predictions = np.argmax(logits, axis=-1)
+        loss = nn.CrossEntropyLoss()(predictions, labels)
+        perplexity = np.exp(loss)
+        return {"perplexity": perplexity}
+
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
@@ -345,6 +353,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             **data_collator_kwargs,
         ),
         callbacks=callbacks,
+        compute_metrics=compute_metrics,
         **trainer_kwargs,
     )
 
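
Note on the added `compute_metrics`: as written it feeds argmax'd NumPy predictions into `nn.CrossEntropyLoss`, which expects float logits as torch tensors (not integer class indices), so the exponentiated value would not be a real perplexity. A minimal corrected sketch, assuming the usual causal-LM label convention of `-100` for ignored positions and that the evaluation logits fit in memory, might look like this:

```python
from typing import Any, Dict

import numpy as np
import torch
from torch import nn
from transformers import EvalPrediction


def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
    # HF Trainer hands compute_metrics NumPy arrays: logits and labels.
    logits, labels = eval_preds
    logits = torch.from_numpy(logits)
    labels = torch.from_numpy(labels)

    # Shift so that tokens < n predict token n (standard causal-LM alignment).
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Cross-entropy over the vocabulary from the logits themselves, not the
    # argmax predictions; ignore_index=-100 skips padding / prompt tokens.
    loss = nn.CrossEntropyLoss(ignore_index=-100)(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
    return {"perplexity": float(np.exp(loss.item()))}
```

Alternatively, since the Trainer already reports `eval_loss`, perplexity can be derived as `exp(eval_loss)` without materializing the full logits; if logits are needed, `preprocess_logits_for_metrics` can be used to shrink them before they are accumulated.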