Fix shapes
This commit is contained in:
@@ -696,17 +696,14 @@ def mistral_causallm_forward(
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
shift_logits = logits
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
if not hasattr(self, "extra_ignored_labels"):
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
|
||||||
# Flatten the tokens
|
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
|
||||||
# FAST CROSS ENTROPY
|
# FAST CROSS ENTROPY
|
||||||
if self.config.vocab_size > 65536:
|
|
||||||
raise Exception("Fast cross entropy is only compatible with vocab_size <= 65536")
|
|
||||||
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user