From 1d02606934d506cda70cdaa474fe20d343e56c96 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 8 Jul 2023 12:14:54 -0400
Subject: [PATCH] compute perplexity from cross entropy

---
 src/axolotl/utils/trainer.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 98ff9b3b9..103a38715 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -7,14 +7,15 @@ import os
 import sys
 from dataclasses import field
 from pathlib import Path
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import bitsandbytes as bnb
+import numpy as np
 import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
@@ -329,6 +330,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             num_proc=32,
         )
 
+    def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
+        """Compute eval perplexity as exp(mean token-level cross-entropy)."""
+        logits, labels = eval_preds
+        # Shift so each position's logits predict the *next* token (causal LM).
+        shift_logits = torch.from_numpy(logits)[..., :-1, :].contiguous()
+        shift_labels = torch.from_numpy(labels)[..., 1:].contiguous().long()
+        # CrossEntropyLoss needs raw float logits (not argmax'd token ids);
+        # its default ignore_index=-100 skips padded/masked label positions.
+        loss = nn.CrossEntropyLoss()(
+            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        )
+        return {"perplexity": float(np.exp(loss.item()))}
+
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
@@ -345,6 +359,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             **data_collator_kwargs,
         ),
         callbacks=callbacks,
+        compute_metrics=compute_metrics,
         **trainer_kwargs,
     )
 