compute perplexity from cross entropy

This commit is contained in:
Wing Lian
2023-07-08 12:14:54 -04:00
parent 687d889928
commit 1d02606934

View File

@@ -7,14 +7,15 @@ import os
import sys import sys
from dataclasses import field from dataclasses import field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Any, Dict, Optional
import bitsandbytes as bnb import bitsandbytes as bnb
import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
@@ -329,6 +330,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_proc=32, num_proc=32,
) )
def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
logits, labels = eval_preds
predictions = np.argmax(logits, axis=-1)
loss = nn.CrossEntropyLoss()(predictions, labels)
perplexity = np.exp(loss)
return {"perplexity": perplexity}
trainer_cls = ( trainer_cls = (
OneCycleLRSchedulerTrainer OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora") if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
@@ -345,6 +353,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
**data_collator_kwargs, **data_collator_kwargs,
), ),
callbacks=callbacks, callbacks=callbacks,
compute_metrics=compute_metrics,
**trainer_kwargs, **trainer_kwargs,
) )