From bdd34c74001e7e63248d35c1cd51846fdc72df48 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 2 Aug 2023 21:36:39 -0400 Subject: [PATCH] weighted cross-entropy loss (CEL) fixes --- src/axolotl/utils/trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d4245563e..2dedfab31 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -37,9 +37,18 @@ LOG = logging.getLogger("axolotl") def weighted_cross_entropy( logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor ): - return ( - weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none") - ).sum() + # Flatten the logits, labels, and weights tensors + logits = logits.view( + -1, logits.size(-1) + ) # logits becomes of shape [batch_size*sequence_length, vocab_size] + labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length] + weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length] + + # Compute the unweighted cross entropy loss + losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none") + + # Apply the weights to the losses and compute their sum + return (weights * losses).sum() @torch.jit.script @@ -66,7 +75,6 @@ def create_weighted_mask(labels: torch.Tensor): return mask_weights -@torch.jit.script def trainer_weighted_loss(model_output, labels, shift_labels=True): logits = ( model_output["logits"] if isinstance(model_output, dict) else model_output[0]