use different perplexity calc
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-07-10 13:43:50 -04:00
parent e7c84254ba
commit 0f2a16aa33

View File

@@ -10,12 +10,17 @@ from pathlib import Path
from typing import Any, Dict, Optional
import bitsandbytes as bnb
import numpy as np
import torch.cuda
import torch.nn.functional as F
import transformers
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
from transformers import (
EarlyStoppingCallback,
EvalPrediction,
Trainer,
TrainingArguments,
)
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import (
@@ -333,19 +338,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.compute_perplexity_metrics:
def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
logits, labels = eval_preds
# Convert numpy ndarrays to PyTorch tensors
logits_tensor = torch.tensor(logits)
labels_tensor = torch.tensor(labels)
# Adjust labels to match expected size
labels_tensor = labels_tensor.view(-1)
loss = nn.CrossEntropyLoss()(
logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
logits = eval_preds.predictions
labels = eval_preds.label_ids
cross_entropy_loss = F.cross_entropy(
logits.view(-1, model.config.vocab_size), labels.view(-1)
)
perplexity = np.exp(
loss.item()
) # Use .item() to get a Python number from a tensor containing a single value
return {"perplexity": perplexity}
perplexity = torch.exp(cross_entropy_loss)
return {"cross_entropy_loss": cross_entropy_loss, "perplexity": perplexity}
trainer_kwargs["compute_metrics"] = compute_metrics