From e7c84254ba830154f34ceaa0dbb34fc582f01b14 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 8 Jul 2023 13:35:24 -0400 Subject: [PATCH] fix perplexity calculation and make it configurable --- src/axolotl/utils/trainer.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 103a38715..09dcf30de 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -330,12 +330,24 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): num_proc=32, ) - def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]: - logits, labels = eval_preds - predictions = np.argmax(logits, axis=-1) - loss = nn.CrossEntropyLoss()(predictions, labels) - perplexity = np.exp(loss) - return {"perplexity": perplexity} + if cfg.compute_perplexity_metrics: + + def compute_metrics(eval_preds: EvalPrediction) -> Dict[str, Any]: + logits, labels = eval_preds + # Convert numpy ndarrays to PyTorch tensors + logits_tensor = torch.tensor(logits) + labels_tensor = torch.tensor(labels) + # Adjust labels to match expected size + labels_tensor = labels_tensor.view(-1) + loss = nn.CrossEntropyLoss()( + logits_tensor.view(-1, logits_tensor.size(-1)), labels_tensor + ) + perplexity = np.exp( + loss.item() + ) # Use .item() to get a Python number from a tensor containing a single value + return {"perplexity": perplexity} + + trainer_kwargs["compute_metrics"] = compute_metrics trainer_cls = ( OneCycleLRSchedulerTrainer @@ -353,7 +365,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): **data_collator_kwargs, ), callbacks=callbacks, - compute_metrics=compute_metrics, **trainer_kwargs, )