compute perplexity from cross entropy
This commit is contained in:
@@ -7,14 +7,15 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
import numpy as np
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import transformers
|
import transformers
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, EvalPrediction, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
@@ -329,6 +330,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
num_proc=32,
|
num_proc=32,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def compute_metrics(eval_preds: "EvalPrediction") -> Dict[str, Any]:
    """Compute perplexity as exp(mean cross entropy) over the evaluation set.

    Args:
        eval_preds: A ``(logits, labels)`` pair. ``logits`` carries the
            unnormalised class scores on its last axis; ``labels`` holds the
            integer target ids with one fewer dimension.

    Returns:
        A dict with a single ``"perplexity"`` entry holding a Python float.

    NOTE(review): logits and labels are compared position by position. If the
    model is a causal LM whose loss shifts labels internally, the same shift
    probably has to be applied here -- confirm against the model in use.
    """
    logits, labels = eval_preds
    # Some models return a tuple of outputs; presumably the logits come first
    # -- TODO confirm for the models used with this trainer.
    if isinstance(logits, tuple):
        logits = logits[0]

    # The loss needs the raw per-class scores as float tensors. Feeding it
    # argmax indices (or NumPy arrays) cannot produce a cross entropy.
    logits_t = torch.as_tensor(np.asarray(logits), dtype=torch.float32)
    labels_t = torch.as_tensor(np.asarray(labels), dtype=torch.long)

    # Flatten every leading axis so the loss sees (N, C) scores and (N,) targets.
    # CrossEntropyLoss skips targets equal to its default ignore_index (-100).
    loss = nn.CrossEntropyLoss()(
        logits_t.reshape(-1, logits_t.shape[-1]),
        labels_t.reshape(-1),
    )
    perplexity = torch.exp(loss).item()
    return {"perplexity": perplexity}
|
||||||
|
|
||||||
trainer_cls = (
|
trainer_cls = (
|
||||||
OneCycleLRSchedulerTrainer
|
OneCycleLRSchedulerTrainer
|
||||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||||
@@ -345,6 +353,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
),
|
),
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
compute_metrics=compute_metrics,
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user