Fix shapes

This commit is contained in:
Casper Hansen
2023-12-06 20:26:25 +00:00
parent add3b139ed
commit 538c004080

View File

@@ -696,17 +696,14 @@ def mistral_causallm_forward(
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
shift_labels = shift_labels.to(shift_logits.device)
# FAST CROSS ENTROPY
if self.config.vocab_size > 65536:
raise Exception("Fast cross entropy is only compatible with vocab_size <= 65536")
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
if not return_dict: