diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 645ef0231..81c804309 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -696,17 +696,14 @@ def mistral_causallm_forward( loss = None if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) + shift_logits = logits + if not hasattr(self, "extra_ignored_labels"): + self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device) + + shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) shift_labels = shift_labels.to(shift_logits.device) # FAST CROSS ENTROPY - if self.config.vocab_size > 65536: - raise Exception("Fast cross entropy is only compatible with vocab_size <= 65536") loss = fast_cross_entropy_loss(shift_logits, shift_labels) if not return_dict: