Compare commits
3 Commits
fix/hpc-ro
...
compute-pe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f2a16aa33 | ||
|
|
e7c84254ba | ||
|
|
1d02606934 |
@@ -7,14 +7,20 @@ import os
|
||||
import sys
|
||||
from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch.cuda
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
EvalPrediction,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.callbacks import (
|
||||
@@ -329,6 +335,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
if cfg.compute_perplexity_metrics:
|
||||
|
||||
def compute_metrics(eval_preds: "EvalPrediction") -> Dict[str, Any]:
    """Compute token-level cross-entropy loss and perplexity for an eval run.

    Args:
        eval_preds: Object exposing ``predictions`` (logits whose last axis
            is the vocabulary) and ``label_ids`` (target token ids). Both may
            be NumPy arrays or tensors; ``predictions`` may also be a tuple
            whose first element holds the logits.

    Returns:
        A dict with ``"cross_entropy_loss"`` and ``"perplexity"`` as plain
        Python floats.
    """
    logits = eval_preds.predictions
    # Some models return extra outputs alongside the logits; keep the logits.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    # The evaluation loop hands over NumPy arrays, which have no tensor-style
    # ``view``; convert explicitly. Float32 avoids half-precision loss math.
    logits = torch.as_tensor(logits).float()
    labels = torch.as_tensor(eval_preds.label_ids).long()

    # NOTE(review): assumes a causal LM, where the logits at position t score
    # the token at position t + 1 -- confirm against the model being trained.
    shifted_logits = logits[..., :-1, :]
    shifted_labels = labels[..., 1:]

    # Take the class count from the logits themselves so the metric stays
    # correct even if the embedding matrix was resized.
    num_classes = shifted_logits.shape[-1]

    # ``reshape`` (not ``view``) because the shifted slices are not contiguous.
    # Positions labelled -100 are skipped by cross_entropy's default
    # ignore_index.
    cross_entropy_loss = F.cross_entropy(
        shifted_logits.reshape(-1, num_classes), shifted_labels.reshape(-1)
    )
    perplexity = torch.exp(cross_entropy_loss)

    # Return plain floats so the metrics can be logged and serialized.
    return {
        "cross_entropy_loss": cross_entropy_loss.item(),
        "perplexity": perplexity.item(),
    }
|
||||
|
||||
trainer_kwargs["compute_metrics"] = compute_metrics
|
||||
|
||||
trainer_cls = (
|
||||
OneCycleLRSchedulerTrainer
|
||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||
|
||||
Reference in New Issue
Block a user