compute perplexity from cross entropy

This commit is contained in:
Wing Lian
2023-07-08 12:14:54 -04:00
parent 687d889928
commit 1d02606934

View File

@@ -7,14 +7,15 @@ import os
import sys
from dataclasses import field
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional
import bitsandbytes as bnb
import numpy as np
import torch.cuda
import transformers
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import (
@@ -329,6 +330,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_proc=32,
)
def compute_metrics(eval_preds: "EvalPrediction") -> Dict[str, Any]:
    """Compute eval perplexity as the exponential of the mean cross-entropy.

    Args:
        eval_preds: ``(logits, labels)`` pair from the HF Trainer eval loop;
            ``logits`` carries a trailing vocab/class dimension and ``labels``
            holds class ids, with ``-100`` marking positions to ignore
            (CrossEntropyLoss's default ``ignore_index``).

    Returns:
        ``{"perplexity": float}`` suitable for Trainer metric logging.
    """
    logits, labels = eval_preds
    # Bug fix: cross entropy must be computed over the raw float logits, not
    # over np.argmax(logits) — argmax discards the probability distribution
    # the loss is defined on and yields an integer array that
    # nn.CrossEntropyLoss cannot consume. Torch tensors are also required;
    # the Trainer hands numpy arrays to compute_metrics.
    logits_t = torch.as_tensor(logits).float()
    labels_t = torch.as_tensor(labels).long()
    # Flatten (..., vocab) -> (N, vocab) so the loss averages over every
    # non-ignored position regardless of input rank.
    loss = nn.CrossEntropyLoss()(
        logits_t.reshape(-1, logits_t.size(-1)), labels_t.reshape(-1)
    )
    # NOTE(review): for causal-LM perplexity the logits/labels should
    # arguably be shifted by one position first — confirm against how the
    # model computes its own training loss.
    return {"perplexity": float(torch.exp(loss))}
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
@@ -345,6 +353,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
**data_collator_kwargs,
),
callbacks=callbacks,
compute_metrics=compute_metrics,
**trainer_kwargs,
)