weighted CEL fixes
@@ -37,9 +37,18 @@ LOG = logging.getLogger("axolotl")
 def weighted_cross_entropy(
     logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
 ):
-    return (
-        weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")
-    ).sum()
+    # Flatten the logits, labels, and weights tensors
+    logits = logits.view(
+        -1, logits.size(-1)
+    )  # logits becomes of shape [batch_size*sequence_length, vocab_size]
+    labels = labels.view(-1)  # labels becomes of shape [batch_size*sequence_length]
+    weights = weights.view(-1)  # weights becomes of shape [batch_size*sequence_length]
+
+    # Compute the unweighted cross entropy loss
+    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+
+    # Apply the weights to the losses and compute their sum
+    return (weights * losses).sum()
 
 
 @torch.jit.script
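For context on why the flattening matters: torch.nn.functional.cross_entropy treats the second dimension of a 3-D input as the class dimension, so calling it directly on logits of shape [batch_size, sequence_length, vocab_size] misreads sequence positions as classes. Below is a minimal sketch of the patched function in isolation, with a cross-check against an explicit per-token log-softmax computation; the shape names and sizes are illustrative, not part of the commit.

import torch
import torch.nn.functional as F

def weighted_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
    # Same logic as the patched function above: flatten, per-token CE, weighted sum.
    logits = logits.view(-1, logits.size(-1))
    labels = labels.view(-1)
    weights = weights.view(-1)
    losses = F.cross_entropy(logits, labels, reduction="none")
    return (weights * losses).sum()

# Illustrative sizes, not from the commit.
batch_size, sequence_length, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, sequence_length, vocab_size)
labels = torch.randint(vocab_size, (batch_size, sequence_length))
weights = torch.rand(batch_size, sequence_length)

loss = weighted_cross_entropy(logits, labels, weights)

# Cross-check: gather each token's log-probability and apply the weights by hand.
log_probs = F.log_softmax(logits, dim=-1)
per_token = -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(loss, (weights * per_token).sum())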
@@ -66,7 +75,6 @@ def create_weighted_mask(labels: torch.Tensor):
     return mask_weights
 
 
-@torch.jit.script
 def trainer_weighted_loss(model_output, labels, shift_labels=True):
     logits = (
         model_output["logits"] if isinstance(model_output, dict) else model_output[0]
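The line removed in this hunk is the @torch.jit.script decorator on trainer_weighted_loss. A plausible reason (not stated in the commit): TorchScript types unannotated parameters as Tensor, so a function whose model_output may be either a dict or a tuple of model outputs cannot be scripted and then called with a dict. A hypothetical repro of that failure mode, not code from the commit:

import torch

def pick_logits(model_output):
    # Same branching pattern as trainer_weighted_loss.
    return model_output["logits"] if isinstance(model_output, dict) else model_output[0]

try:
    scripted = torch.jit.script(pick_logits)
    scripted({"logits": torch.randn(2, 3)})  # dict input
except Exception as err:
    # Either scripting or the dict call fails, because TorchScript
    # assumes the untyped model_output parameter is a Tensor.
    print(f"TorchScript rejects the dict path: {err}")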