weighted CEL fixes
This commit is contained in:
@@ -37,9 +37,18 @@ LOG = logging.getLogger("axolotl")
|
|||||||
def weighted_cross_entropy(
|
def weighted_cross_entropy(
|
||||||
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
||||||
):
|
):
|
||||||
return (
|
# Flatten the logits, labels, and weights tensors
|
||||||
weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")
|
logits = logits.view(
|
||||||
).sum()
|
-1, logits.size(-1)
|
||||||
|
) # logits becomes of shape [batch_size*sequence_length, vocab_size]
|
||||||
|
labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
|
||||||
|
weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
|
||||||
|
|
||||||
|
# Compute the unweighted cross entropy loss
|
||||||
|
losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
|
||||||
|
|
||||||
|
# Apply the weights to the losses and compute their sum
|
||||||
|
return (weights * losses).sum()
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -66,7 +75,6 @@ def create_weighted_mask(labels: torch.Tensor):
|
|||||||
return mask_weights
|
return mask_weights
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||||
logits = (
|
logits = (
|
||||||
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user