Fix shapes
This commit is contained in:
@@ -696,17 +696,14 @@ def mistral_causallm_forward(
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
shift_logits = logits
|
||||
if not hasattr(self, "extra_ignored_labels"):
|
||||
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
|
||||
|
||||
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
|
||||
# FAST CROSS ENTROPY
|
||||
if self.config.vocab_size > 65536:
|
||||
raise Exception("Fast cross entropy is only compatible with vocab_size <= 65536")
|
||||
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
Reference in New Issue
Block a user