weighted CEL fixes

This commit is contained in:
Wing Lian
2023-08-02 21:36:39 -04:00
parent c6cc54c7d9
commit bdd34c7400

View File

@@ -37,9 +37,18 @@ LOG = logging.getLogger("axolotl")
def weighted_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
    """Return the sum of per-position cross-entropy losses scaled by weights.

    All three tensors are flattened before the loss is computed, so any
    leading dimensions are accepted as long as ``labels`` and ``weights``
    contain one entry per row of the flattened ``logits``.

    Args:
        logits: Unnormalized scores; the last dimension indexes the classes.
        labels: Target class indices, one per position.
        weights: Multiplicative weight for each position's loss.

    Returns:
        A scalar tensor: the weighted losses summed over every position.
        The sum is not normalized by the number of positions or by the
        total weight.
    """
    # Collapse every leading dimension so cross_entropy sees (N, num_classes)
    # logits and (N,) targets; without this, dim 1 of a batched input would
    # be misread as the class dimension.
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # slices taken along the sequence dimension.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_labels = labels.reshape(-1)
    flat_weights = weights.reshape(-1)

    # Keep the per-position losses (no reduction) so each can be weighted.
    losses = torch.nn.functional.cross_entropy(
        flat_logits, flat_labels, reduction="none"
    )
    return (flat_weights * losses).sum()
@torch.jit.script
@@ -66,7 +75,6 @@ def create_weighted_mask(labels: torch.Tensor):
return mask_weights
@torch.jit.script
def trainer_weighted_loss(model_output, labels, shift_labels=True):
logits = (
model_output["logits"] if isinstance(model_output, dict) else model_output[0]