use different perplexity calc
This commit is contained in:
@@ -10,12 +10,17 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import numpy as np
|
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
|
from transformers import (
|
||||||
|
EarlyStoppingCallback,
|
||||||
|
EvalPrediction,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
@@ -333,19 +338,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.compute_perplexity_metrics:
|
if cfg.compute_perplexity_metrics:
|
||||||
|
|
||||||
def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
|
def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
|
||||||
logits, labels = eval_preds
|
logits = eval_preds.predictions
|
||||||
# Convert numpy ndarrays to PyTorch tensors
|
labels = eval_preds.label_ids
|
||||||
logits_tensor = torch.tensor(logits)
|
cross_entropy_loss = F.cross_entropy(
|
||||||
labels_tensor = torch.tensor(labels)
|
logits.view(-1, model.config.vocab_size), labels.view(-1)
|
||||||
# Adjust labels to match expected size
|
|
||||||
labels_tensor = labels_tensor.view(-1)
|
|
||||||
loss = nn.CrossEntropyLoss()(
|
|
||||||
logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
|
|
||||||
)
|
)
|
||||||
perplexity = np.exp(
|
perplexity = torch.exp(cross_entropy_loss)
|
||||||
loss.item()
|
return {"cross_entropy_loss": cross_entropy_loss, "perplexity": perplexity}
|
||||||
) # Use .item() to get a Python number from a tensor containing a single value
|
|
||||||
return {"perplexity": perplexity}
|
|
||||||
|
|
||||||
trainer_kwargs["compute_metrics"] = compute_metrics
|
trainer_kwargs["compute_metrics"] = compute_metrics
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user