From 0f2a16aa33faa5ec73bb783d3e0ac17cb7cd89d9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 10 Jul 2023 13:43:50 -0400
Subject: [PATCH] use different perplexity calc

---
 src/axolotl/utils/trainer.py | 27 +++++++++++++--------------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 09dcf30de..f939576c9 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -10,12 +10,17 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import bitsandbytes as bnb
-import numpy as np
 import torch.cuda
+import torch.nn.functional as F
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
+from transformers import (
+    EarlyStoppingCallback,
+    EvalPrediction,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
@@ -333,19 +338,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.compute_perplexity_metrics:
 
         def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
-            logits, labels = eval_preds
-            # Convert numpy ndarrays to PyTorch tensors
-            logits_tensor = torch.tensor(logits)
-            labels_tensor = torch.tensor(labels)
-            # Adjust labels to match expected size
-            labels_tensor = labels_tensor.view(-1)
-            loss = nn.CrossEntropyLoss()(
-                logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
+            logits = eval_preds.predictions
+            labels = eval_preds.label_ids
+            cross_entropy_loss = F.cross_entropy(
+                logits.view(-1, model.config.vocab_size), labels.view(-1)
             )
-            perplexity = np.exp(
-                loss.item()
-            )  # Use .item() to get a Python number from a tensor containing a single value
-            return {"perplexity": perplexity}
+            perplexity = torch.exp(cross_entropy_loss)
+            return {"cross_entropy_loss": cross_entropy_loss, "perplexity": perplexity}
 
         trainer_kwargs["compute_metrics"] = compute_metrics
 