Compare commits

...

3 Commits

Author       SHA1          Message                                                Date
Wing Lian    0f2a16aa33    use different perplexity calc                          2023-07-10 13:43:50 -04:00
Wing Lian    e7c84254ba    fix perplexity calculation and make it configurable    2023-07-10 12:49:51 -04:00
Wing Lian    1d02606934    compute perplexity from cross entropy                  2023-07-10 12:49:47 -04:00
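(The thread running through these three commits: a language model's perplexity is the exponential of its mean per-token cross-entropy, perplexity = exp(cross_entropy), so it can be derived from the evaluation loss rather than computed independently.)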


@@ -7,14 +7,20 @@ import os
 import sys
 from dataclasses import field
 from pathlib import Path
-from typing import Optional
+from typing import Any, Dict, Optional

 import bitsandbytes as bnb
 import torch.cuda
+import torch.nn.functional as F
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import (
+    EarlyStoppingCallback,
+    EvalPrediction,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.trainer_pt_utils import get_parameter_names

 from axolotl.utils.callbacks import (
@@ -329,6 +335,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             num_proc=32,
         )

+    if cfg.compute_perplexity_metrics:
+
+        def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
+            logits = eval_preds.predictions
+            labels = eval_preds.label_ids
+            cross_entropy_loss = F.cross_entropy(
+                logits.view(-1, model.config.vocab_size), labels.view(-1)
+            )
+            perplexity = torch.exp(cross_entropy_loss)
+            return {"cross_entropy_loss": cross_entropy_loss, "perplexity": perplexity}
+
+        trainer_kwargs["compute_metrics"] = compute_metrics
+
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
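
For reference, the compute_metrics hook added above leans on the identity perplexity = exp(cross-entropy). Here is a minimal, self-contained sketch of the same calculation outside the Trainer; the shapes and names (vocab_size, logits, labels) are illustrative stand-ins, not values from axolotl:

import torch
import torch.nn.functional as F

# Toy shapes: batch of 2 sequences, 4 token positions, vocabulary of 10.
# In the hunk above, logits comes from eval_preds.predictions and labels
# from eval_preds.label_ids.
vocab_size = 10
logits = torch.randn(2, 4, vocab_size)
labels = torch.randint(0, vocab_size, (2, 4))

# F.cross_entropy expects (N, C) logits and (N,) class indices, so both
# tensors are flattened across the batch and sequence dimensions. The
# result is the mean negative log-likelihood per token.
cross_entropy = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

# Perplexity is the exponential of the mean cross-entropy.
perplexity = torch.exp(cross_entropy)
print(f"cross_entropy={cross_entropy:.4f}  perplexity={perplexity:.4f}")

With untrained random logits the printed perplexity lands near chance level, on the order of the vocabulary size; training drives both numbers down together.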