use different perplexity calc
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-07-10 13:43:50 -04:00
parent e7c84254ba
commit 0f2a16aa33

View File

@@ -10,12 +10,17 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 import bitsandbytes as bnb
-import numpy as np
 import torch.cuda
+import torch.nn.functional as F
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
+from transformers import (
+    EarlyStoppingCallback,
+    EvalPrediction,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.trainer_pt_utils import get_parameter_names
 from axolotl.utils.callbacks import (
@@ -333,19 +338,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.compute_perplexity_metrics:
         def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
-            logits, labels = eval_preds
-            # Convert numpy ndarrays to PyTorch tensors
-            logits_tensor = torch.tensor(logits)
-            labels_tensor = torch.tensor(labels)
-            # Adjust labels to match expected size
-            labels_tensor = labels_tensor.view(-1)
-            loss = nn.CrossEntropyLoss()(
-                logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
+            logits = eval_preds.predictions
+            labels = eval_preds.label_ids
+            cross_entropy_loss = F.cross_entropy(
+                logits.view(-1, model.config.vocab_size), labels.view(-1)
             )
-            perplexity = np.exp(
-                loss.item()
-            )  # Use .item() to get a Python number from a tensor containing a single value
-            return {"perplexity": perplexity}
+            perplexity = torch.exp(cross_entropy_loss)
+            return {"cross_entropy_loss": cross_entropy_loss, "perplexity": perplexity}
         trainer_kwargs["compute_metrics"] = compute_metrics