diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index d4245563e..2dedfab31 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -37,9 +37,18 @@ LOG = logging.getLogger("axolotl")
 def weighted_cross_entropy(
     logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
 ):
-    return (
-        weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")
-    ).sum()
+    # Flatten the logits, labels, and weights tensors
+    logits = logits.view(
+        -1, logits.size(-1)
+    )  # logits becomes of shape [batch_size*sequence_length, vocab_size]
+    labels = labels.view(-1)  # labels becomes of shape [batch_size*sequence_length]
+    weights = weights.view(-1)  # weights becomes of shape [batch_size*sequence_length]
+
+    # Compute the unweighted cross entropy loss
+    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+
+    # Apply the weights to the losses and compute their sum
+    return (weights * losses).sum()
 
 
 @torch.jit.script
@@ -66,7 +75,6 @@ def create_weighted_mask(labels: torch.Tensor):
     return mask_weights
 
 
-@torch.jit.script
 def trainer_weighted_loss(model_output, labels, shift_labels=True):
     logits = (
         model_output["logits"] if isinstance(model_output, dict) else model_output[0]
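
For reference, a minimal usage sketch of the revised `weighted_cross_entropy` is shown below. It is not part of the patch: the function body is copied from the diff above, and the tensor dimensions are made-up placeholders chosen to match the shapes described in the added comments.

```python
import torch


def weighted_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
    # Body copied from the patched function above.
    logits = logits.view(-1, logits.size(-1))
    labels = labels.view(-1)
    weights = weights.view(-1)
    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
    return (weights * losses).sum()


# Placeholder dimensions for illustration only.
batch_size, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch_size, seq_len, vocab_size)         # [batch, seq, vocab]
labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # [batch, seq]
weights = torch.ones(batch_size, seq_len)                     # per-token weights

loss = weighted_cross_entropy(logits, labels, weights)
print(loss)  # scalar tensor: weighted sum of the per-token losses
```

Flattening to `[batch_size*sequence_length, ...]` lets `torch.nn.functional.cross_entropy` compute one loss per token, so each per-token weight multiplies exactly one loss term before the sum.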