fix perplexity calculation and make it configurable

This commit is contained in:
Wing Lian
2023-07-08 13:35:24 -04:00
parent 1d02606934
commit e7c84254ba

View File

@@ -330,12 +330,24 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_proc=32, num_proc=32,
) )
def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]: if cfg.compute_perplexity_metrics:
logits, labels = eval_preds
predictions = np.argmax(logits, axis=-1) def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
loss = nn.CrossEntropyLoss()(predictions, labels) logits, labels = eval_preds
perplexity = np.exp(loss) # Convert numpy ndarrays to PyTorch tensors
return {"perplexity": perplexity} logits_tensor = torch.tensor(logits)
labels_tensor = torch.tensor(labels)
# Adjust labels to match expected size
labels_tensor = labels_tensor.view(-1)
loss = nn.CrossEntropyLoss()(
logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
)
perplexity = np.exp(
loss.item()
) # Use .item() to get a Python number from a tensor containing a single value
return {"perplexity": perplexity}
trainer_kwargs["compute_metrics"] = compute_metrics
trainer_cls = ( trainer_cls = (
OneCycleLRSchedulerTrainer OneCycleLRSchedulerTrainer
@@ -353,7 +365,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
**data_collator_kwargs, **data_collator_kwargs,
), ),
callbacks=callbacks, callbacks=callbacks,
compute_metrics=compute_metrics,
**trainer_kwargs, **trainer_kwargs,
) )