fix perplexity calculation and make it configurable

This commit is contained in:
Wing Lian
2023-07-08 13:35:24 -04:00
parent 1d02606934
commit e7c84254ba

View File

@@ -330,13 +330,25 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_proc=32, num_proc=32,
) )
if cfg.compute_perplexity_metrics:
def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]: def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]:
logits, labels = eval_preds logits, labels = eval_preds
predictions = np.argmax(logits, axis=-1) # Convert numpy ndarrays to PyTorch tensors
loss = nn.CrossEntropyLoss()(predictions, labels) logits_tensor = torch.tensor(logits)
perplexity = np.exp(loss) labels_tensor = torch.tensor(labels)
# Adjust labels to match expected size
labels_tensor = labels_tensor.view(-1)
loss = nn.CrossEntropyLoss()(
logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor
)
perplexity = np.exp(
loss.item()
) # Use .item() to get a Python number from a tensor containing a single value
return {"perplexity": perplexity} return {"perplexity": perplexity}
trainer_kwargs["compute_metrics"] = compute_metrics
trainer_cls = ( trainer_cls = (
OneCycleLRSchedulerTrainer OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora") if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
@@ -353,7 +365,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
**data_collator_kwargs, **data_collator_kwargs,
), ),
callbacks=callbacks, callbacks=callbacks,
compute_metrics=compute_metrics,
**trainer_kwargs, **trainer_kwargs,
) )