fix perplexity calculation and make it configurable

Author: Wing Lian
Date: 2023-07-08 13:35:24 -04:00
parent 1d02606934
commit e7c84254ba


@@ -330,12 +330,24 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         num_proc=32,
     )
 
-    def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
-        logits, labels = eval_preds
-        predictions = np.argmax(logits, axis=-1)
-        loss = nn.CrossEntropyLoss()(predictions, labels)
-        perplexity = np.exp(loss)
-        return {"perplexity": perplexity}
+    if cfg.compute_perplexity_metrics:
+
+        def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
+            logits, labels = eval_preds
+            # Convert numpy ndarrays to PyTorch tensors
+            logits_tensor = torch.tensor(logits)
+            labels_tensor = torch.tensor(labels)
+            # Flatten labels to match the flattened logits
+            labels_tensor = labels_tensor.view(-1)
+            loss = nn.CrossEntropyLoss()(
+                logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
+            )
+            perplexity = np.exp(
+                loss.item()
+            )  # .item() extracts a Python float from the single-element loss tensor
+            return {"perplexity": perplexity}
+
+        trainer_kwargs["compute_metrics"] = compute_metrics
 
     trainer_cls = (
         OneCycleLRSchedulerTrainer
@@ -353,7 +365,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             **data_collator_kwargs,
         ),
         callbacks=callbacks,
-        compute_metrics=compute_metrics,
         **trainer_kwargs,
     )
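
The net effect of the diff: the old compute_metrics fed np.argmax class indices into nn.CrossEntropyLoss, which expects raw logits of shape (N, vocab_size), and applied np.exp directly to a tensor; the fixed version flattens logits and labels, extracts the scalar loss via .item(), and only registers the metric when cfg.compute_perplexity_metrics is set, wiring it through trainer_kwargs instead of a hard-coded compute_metrics= argument. As a standalone illustration of the corrected computation (not part of the commit; the helper name perplexity_from_logits and the toy inputs are invented for this sketch):

import numpy as np
import torch
from torch import nn


def perplexity_from_logits(logits: np.ndarray, labels: np.ndarray) -> float:
    """Perplexity = exp(mean token-level cross-entropy).

    logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    Label positions set to -100 (the Hugging Face padding convention)
    are skipped, since nn.CrossEntropyLoss defaults to ignore_index=-100.
    """
    logits_tensor = torch.tensor(logits)
    labels_tensor = torch.tensor(labels).view(-1)
    loss = nn.CrossEntropyLoss()(
        logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
    )
    return float(np.exp(loss.item()))


# Sanity check: uniform logits over a 4-token vocab should give a
# perplexity of ~4.0, i.e. exp(ln 4).
print(perplexity_from_logits(np.zeros((1, 2, 4), dtype=np.float32), np.array([[1, 3]])))

A convenient side effect of nn.CrossEntropyLoss defaulting to ignore_index=-100 is that Hugging Face's masked label positions drop out of the average automatically, so padded evaluation batches do not skew the perplexity.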