From add3b139ed7967ad9f962115848251cd7e8ebf14 Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 6 Dec 2023 20:17:42 +0100 Subject: [PATCH] Mistral with fast cross entropy --- src/axolotl/monkeypatch/cross_entropy.py | 154 ++++++++++++++++++ .../monkeypatch/mistral_attn_hijack_flash.py | 81 ++++++++- 2 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 src/axolotl/monkeypatch/cross_entropy.py diff --git a/src/axolotl/monkeypatch/cross_entropy.py b/src/axolotl/monkeypatch/cross_entropy.py new file mode 100644 index 000000000..9826eebc9 --- /dev/null +++ b/src/axolotl/monkeypatch/cross_entropy.py @@ -0,0 +1,154 @@ +# Adapted from Unsloth +# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py + +import triton +import triton.language as tl +import torch +from .utils import calculate_settings + +@triton.jit +def _cross_entropy_forward(logits_ptr, logits_row_stride, + loss_ptr, + lse_ptr, + labels_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr,): + """ + Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] + Pi = exp(xi) / sum(exp(xi)) + CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ] + = -y [ x - log[sum(exp(x))] ] + = y * (log[sum(exp(x))] - x) + If y == 0: CE_i = 0 + If y == 1: CE_i = logsumexp - x + """ + row_idx = tl.program_id(0) + logits_ptr += row_idx * logits_row_stride + loss_ptr += row_idx + lse_ptr += row_idx + labels_ptr += row_idx + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # TODO: Fixup int32 locations to int64 + label_idx = tl.load(labels_ptr).to(tl.int32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + max_logits = tl.max(logits, 0) + # Maximum stops overflow + lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits + tl.store(lse_ptr, lse) + + if label_idx != -100: + logits_label = tl.load(logits_ptr + label_idx).to(tl.float32) + loss = lse - logits_label + else: + loss = 0.0 + tl.store(loss_ptr, loss) +pass + + +@triton.jit +def _cross_entropy_backward(logits_ptr, logits_row_stride, + dloss_ptr, dloss_row_stride, + lse_ptr, + labels_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr,): + """ + CE_i = -y log(P) = y * (log[sum(exp(x))] - x) + dC/dx = d/dx (y * log[sum(exp(x))] - x * y) + + From https://en.wikipedia.org/wiki/LogSumExp + d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x) + + dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y) + dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick + dC/dx = y * exp[x - logsumexp] - d/dx (x * y) + + If y == 0: dC/dx = 0 + If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1 + If y == 1 and x != label: dC/dx = exp[x - logsumexp] + """ + row_idx = tl.program_id(0) + logits_ptr += row_idx * logits_row_stride + dloss_ptr += row_idx * dloss_row_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + # TODO: Fixup int32 locations to int64 + label_idx = tl.load(labels_ptr + row_idx).to(tl.int32) + + if label_idx != -100: + dloss = tl.load(dloss_ptr) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32) + lse = tl.load(lse_ptr + row_idx) + probs = tl.exp(logits - lse) + + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask) + + + +class CrossEntropyLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, labels): + n_rows, n_cols = logits.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda") + + _cross_entropy_forward[(n_rows,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + n_cols, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.save_for_backward(logits, logsumexp, labels) + return losses + pass + + @staticmethod + def backward(ctx, dlosses): + logits, logsumexp, labels = ctx.saved_tensors + n_rows, n_cols = logits.shape + + _cross_entropy_backward[(n_rows,)]( + logits, logits.stride(0), + dlosses, dlosses.stride(0), + logsumexp, + labels, + n_cols, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) + return logits, None, None, + pass +pass + + +def fast_cross_entropy_loss(logits, labels): + """ + Arguments: + logits: (batch, seq_len, vocab_size) + labels: (batch, seq_len,) + Returns: + losses: float + """ + batch, seq_len, d = logits.shape + assert(labels.shape == (batch, seq_len)) + + loss = CrossEntropyLoss.apply( + logits.view(batch*seq_len, d), + labels.view(-1), + ) + n_items = torch.count_nonzero(labels != -100) + return loss.sum() / n_items +pass \ No newline at end of file diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index e31864b83..645ef0231 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -13,16 +13,20 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func, ) -from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.mistral.modeling_mistral import ( MistralAttention as OriginalMistralAttention, ) from transformers.models.mistral.modeling_mistral import ( MistralDecoderLayer as OriginalMistralDecoderLayer, ) +from transformers.models.mistral.modeling_mistral import ( + MistralForCausalLM as OriginalMistralForCausalLM, +) from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss LOG = logging.getLogger("axolotl.monkeypatch.mistral") @@ -36,6 +40,9 @@ def replace_mistral_attn_with_flash_attn( transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( flashattn_forward ) + transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = ( + mistral_causallm_forward + ) if packed: transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( MistralDecoderLayer @@ -641,3 +648,75 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer): outputs += (present_key_value,) return outputs + +def mistral_causallm_forward( + self: OriginalMistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + *args, **kwargs +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + + # FAST CROSS ENTROPY + if self.config.vocab_size > 65536: + raise Exception("Fast cross entropy is only compatible with vocab_size <= 65536") + loss = fast_cross_entropy_loss(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file