From 538c004080b5213f137fcc52fad464c2e394431e Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Wed, 6 Dec 2023 20:26:25 +0000 Subject: [PATCH] Fix shapes --- .../monkeypatch/mistral_attn_hijack_flash.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 645ef0231..81c804309 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -696,17 +696,14 @@ def mistral_causallm_forward( loss = None if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) + shift_logits = logits + if not hasattr(self, "extra_ignored_labels"): + self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device) + + shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) shift_labels = shift_labels.to(shift_logits.device) # FAST CROSS ENTROPY - if self.config.vocab_size > 65536: - raise Exception("Fast cross entropy is only compatible with vocab_size <= 65536") loss = fast_cross_entropy_loss(shift_logits, shift_labels) if not return_dict: