From 97e86c6d47612f3465eaa73f070a231bd9550602 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 6 Aug 2025 08:02:39 -0400 Subject: [PATCH] drop old patches and code that are no longer needed (#3007) [skip ci] --- .runpod/README.md | 1 - .runpod/src/config/config.yaml | 2 - examples/archived/stablelm-2/1.6b/fft.yml | 1 - examples/llama-2/fft_optimized.yml | 1 - examples/llama-2/lisa.yml | 1 - src/axolotl/loaders/patch_manager.py | 21 +- .../monkeypatch/llama_attn_hijack_flash.py | 657 +----------------- .../monkeypatch/mistral_attn_hijack_flash.py | 640 ----------------- src/axolotl/utils/schemas/config.py | 6 - src/axolotl/utils/schemas/validation.py | 6 +- tests/e2e/patched/test_fused_llama.py | 1 - 11 files changed, 7 insertions(+), 1330 deletions(-) diff --git a/.runpod/README.md b/.runpod/README.md index 2d24f1e5c..8042f4f91 100644 --- a/.runpod/README.md +++ b/.runpod/README.md @@ -185,7 +185,6 @@ datasets: | `flash_attention` | `false` | Use flash attention | | `flash_attn_cross_entropy` | `false` | Flash attention cross entropy | | `flash_attn_rms_norm` | `false` | Flash attention RMS norm | -| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations | | `flash_attn_fuse_mlp` | `false` | Fuse MLP operations | | `sdp_attention` | `false` | Use scaled dot product | | `s2_attention` | `false` | Use shifted sparse attention | diff --git a/.runpod/src/config/config.yaml b/.runpod/src/config/config.yaml index 2a89971fb..f482a7331 100644 --- a/.runpod/src/config/config.yaml +++ b/.runpod/src/config/config.yaml @@ -296,7 +296,6 @@ # flash_attention: # flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only # flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only -# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation # flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # # Whether to use scaled-dot-product attention # # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html @@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION} flash_attention: ${FLASH_ATTENTION} flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY} flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM} -flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV} flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP} sdp_attention: ${SDP_ATTENTION} s2_attention: ${S2_ATTENTION} diff --git a/examples/archived/stablelm-2/1.6b/fft.yml b/examples/archived/stablelm-2/1.6b/fft.yml index d608bc66f..585888f43 100644 --- a/examples/archived/stablelm-2/1.6b/fft.yml +++ b/examples/archived/stablelm-2/1.6b/fft.yml @@ -47,7 +47,6 @@ logging_steps: 1 flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -flash_attn_fuse_qkv: false flash_attn_fuse_mlp: true warmup_ratio: 0.1 diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index 678806473..ea119348e 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -45,7 +45,6 @@ logging_steps: 1 flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -flash_attn_fuse_qkv: false flash_attn_fuse_mlp: true warmup_ratio: 0.1 diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml index 7b92b72e1..d21c01a49 100644 --- a/examples/llama-2/lisa.yml +++ b/examples/llama-2/lisa.yml @@ -49,7 +49,6 @@ logging_steps: 1 flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -flash_attn_fuse_qkv: false 
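# A minimal migration sketch for configs that still set the retired flag:
# drop `flash_attn_fuse_qkv` before upgrading. Assumes PyYAML is available;
# the file name is illustrative.
import yaml

with open("config.yml") as f:
    cfg = yaml.safe_load(f)
cfg.pop("flash_attn_fuse_qkv", None)  # no-op if the key was never set
with open("config.yml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)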
flash_attn_fuse_mlp: true warmup_ratio: 0.1 diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index e16f03649..4273f3cce 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -348,31 +348,21 @@ class PatchManager: patch_self_attn_lora() - def _patch_llama_flash_attention(self, packed=False): + def _patch_llama_flash_attention(self): """Apply Flash Attention patches for LLaMA models.""" from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - if packed: - if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info("patching with flash attention for sample packing") - replace_llama_attn_with_flash_attn( - packed=True, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - ) - elif self.cfg.s2_attention: + if self.cfg.s2_attention: LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( - packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, use_shifted_sparse_attn=True, ) elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: replace_llama_attn_with_flash_attn( - packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, ) @@ -403,7 +393,7 @@ class PatchManager: and self.cfg.sample_packing ): if self.cfg.flash_attention: - self._patch_llama_flash_attention(packed=self.cfg.sample_packing) + self._patch_llama_flash_attention() elif self.cfg.xformers_attention: self._patch_llama_xformers_attention() elif self.cfg.sample_packing: @@ -426,17 +416,12 @@ class PatchManager: from axolotl.monkeypatch.llama_attn_hijack_flash import ( is_xformers_swiglu_available, replace_llama_mlp_with_swiglu, - replace_llama_qkv_with_fused, ) if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): LOG.info("Patching with SwiGLU...") replace_llama_mlp_with_swiglu(model) - if self.cfg.flash_attn_fuse_qkv: - LOG.info("Patching with fused QKV...") - replace_llama_qkv_with_fused(model) - def _apply_unsloth_patches(self, model): """Apply unsloth optimization patches.""" if self.cfg.unsloth_lora_mlp: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 70e36714c..1316b5374 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -3,39 +3,26 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py import warnings -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch -import torch.nn.functional as F import transformers from einops import rearrange from flash_attn.bert_padding import pad_input, unpad_input -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import ( - LlamaAttention, -) -from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer as OriginalLlamaDecoderLayer, -) from transformers.models.llama.modeling_llama import ( LlamaMLP, apply_rotary_pos_emb, repeat_kv, ) -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name +from axolotl.monkeypatch.utils import set_module_name from axolotl.utils.logging import get_logger try: from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_kvpacked_func, - 
flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func, ) except ImportError: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func, - ) from flash_attn.flash_attn_interface import ( flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, ) @@ -82,19 +69,6 @@ def replace_llama_mlp_with_swiglu(model): set_module_name(model, name, mlp) -def replace_llama_qkv_with_fused(model): - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - qkv = FusedAttention( - module.config, - module.q_proj, - module.k_proj, - module.v_proj, - module.o_proj, - ) - set_module_name(model, name, qkv) - - def patch_fa_llama_cross_entropy(): LOG.info( "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy" @@ -142,7 +116,6 @@ def patch_llama_rms_norm(): def replace_llama_attn_with_flash_attn( - packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, rms_norm: Optional[bool] = False, use_shifted_sparse_attn: Optional[bool] = False, @@ -154,16 +127,6 @@ def replace_llama_attn_with_flash_attn( transformers.models.llama.modeling_llama.LlamaAttention.forward = ( flashattn_forward_with_s2attn ) - else: - transformers.models.llama.modeling_llama.LlamaAttention.forward = ( - flashattn_forward - ) - - if packed: - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer - transformers.models.llama.modeling_llama.LlamaModel.forward = ( - llama_model_forward - ) # skip only if explicitly disabled if cross_entropy: @@ -174,49 +137,6 @@ def replace_llama_attn_with_flash_attn( patch_llama_rms_norm() -class FusedAttention(LlamaAttention): - """ - Fused QKV Attention layer for incrementally improved training efficiency - """ - - def __init__( - self, - config, - q: torch.nn.Linear, # pylint: disable=invalid-name - k: torch.nn.Linear, # pylint: disable=invalid-name - v: torch.nn.Linear, # pylint: disable=invalid-name - o: torch.nn.Linear, # pylint: disable=invalid-name - ): - super().__init__(config) - self.config = config - self.init_device = next(iter(q.state_dict().values())).device - - # define equivalent fused qkv projection - self.out_features: List[int] = [q.out_features, k.out_features, v.out_features] - self.qkv_proj = torch.nn.Linear( - q.in_features, sum(self.out_features), device=self.init_device, bias=False - ) - self.o_proj = o - - # overwrite initialized weights with pretrained weights - self.qkv_proj.weight.data = torch.cat( - (q.weight.data, k.weight.data, v.weight.data), dim=0 - ) - - def _post_training(self, model, name): - q_proj, k_proj, v_proj = torch.split( - self.qkv_proj.weight.data, self.out_features, dim=0 - ) - - new_attn = LlamaAttention(self.config) - new_attn.q_proj.weight.data = q_proj - new_attn.k_proj.weight.data = k_proj - new_attn.v_proj.weight.data = v_proj - new_attn.o_proj.weight.data = self.o_proj.weight.data - - set_module_name(model, name, new_attn) - - # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( @@ -355,576 +275,3 @@ def flashattn_forward_with_s2attn( .reshape(bsz, q_len, nheads, self.head_dim) ) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value - - -def flashattn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - 
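# For reference, the removed FusedAttention concatenated the q/k/v projection
# weights so a single matmul produced all three states, then split them back
# apart on the forward pass. A self-contained sketch with illustrative sizes
# (hidden=64, GQA-style 64/32/32 outputs):
import torch

q = torch.nn.Linear(64, 64, bias=False)
k = torch.nn.Linear(64, 32, bias=False)
v = torch.nn.Linear(64, 32, bias=False)

out_features = [q.out_features, k.out_features, v.out_features]
qkv_proj = torch.nn.Linear(64, sum(out_features), bias=False)
qkv_proj.weight.data = torch.cat((q.weight.data, k.weight.data, v.weight.data), dim=0)

x = torch.randn(2, 10, 64)
q_states, k_states, v_states = qkv_proj(x).split(out_features, dim=-1)
assert torch.allclose(q_states, q(x), atol=1e-6)  # splitting recovers q_proj exactly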
past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel - - attention_mask: [bsz, q_len] - """ - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - if isinstance(self, FusedAttention): - query_states, key_states, value_states = self.qkv_proj(hidden_states).split( - self.out_features, dim=-1 - ) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - cos, sin = self.rotary_emb(value_states, position_ids=position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
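# repeat_kv, imported above, broadcasts each key/value head across its group of
# query heads (grouped-query attention). A self-contained equivalent mirroring
# the upstream transformers helper:
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seqlen, head_dim) -> (bsz, num_kv_heads * n_rep, seqlen, head_dim)
    if n_rep == 1:
        return hidden_states
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)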
- ) - - # - # flash-attn v2 start - # - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - # - # flash-attn v2 end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) 
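# Shape walk-through for the packed branch above, with illustrative sizes: two
# sequences of lengths 3 and 5 packed into 8 tokens. The flash-attn call is
# left commented because the kernel needs fp16/bf16 tensors on a CUDA device.
import torch
# from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

nheads, head_dim = 4, 16
qkv = torch.randn(8, 3, nheads, head_dim)                # (total_tokens, 3, nheads, head_dim)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # sequence boundaries
max_seqlen = 5
# out = flash_attn_varlen_qkvpacked_func(qkv.half().cuda(), cu_seqlens.cuda(),
#                                        max_seqlen, causal=True)
# out: (total_tokens, nheads, head_dim); attention never crosses the boundary
# at token 3, which is how sample packing avoids cross-sample contamination.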
- else: - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - def output_pad_fn(output_unpad): - return pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - def output_pad_fn(output_unpad): - return rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - -def llama_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[ # pylint: disable=unused-argument - torch.LongTensor - ] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if 
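# generate_qkv leans on flash-attn's pad/unpad helpers; a round-trip sketch.
# This mirrors the 4-tuple unpacking used in this module (recent flash-attn
# releases return an extra element from unpad_input).
import torch
from flash_attn.bert_padding import pad_input, unpad_input

x = torch.randn(2, 4, 8)                                  # (batch, seqlen, hidden)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, mask)
# x_unpad: (5, 8) holds only the real tokens; cu_seqlens == [0, 3, 5]
x_back = pad_input(x_unpad, indices, 2, 4)                # padding restored as zeros
assert torch.equal(x_back * mask.unsqueeze(-1), x_back)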
input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - padding_mask = None - else: - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None - - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - transformers.logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
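# get_cu_seqlens_from_pos_ids, used above, derives flash-attn's cumulative
# sequence lengths from packed position ids, which restart at 0 for every
# sub-sequence. A minimal sketch of the idea (not the axolotl helper itself):
import torch

def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
    # position_ids: (total_tokens,), e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3]
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    total = torch.tensor([position_ids.numel()], device=position_ids.device)
    seqlens = torch.diff(starts, append=total)
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0).to(torch.int32), (1, 0))
    return cu_seqlens, int(seqlens.max())

cu, max_len = cu_seqlens_from_pos_ids(torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]))
assert cu.tolist() == [0, 3, 5, 9] and max_len == 4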
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module( - *inputs, - ) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - padding_mask, - cu_seqlens, - max_seqlen, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaDecoderLayer(OriginalLlamaDecoderLayer): - """ - patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 3fc22917f..e1be424a3 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -3,53 +3,14 @@ # pylint: disable=duplicate-code from functools import partial -from typing import List, Optional, Tuple, Union -import torch import transformers -from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_kvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, -) -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.mistral.modeling_mistral import ( - MistralAttention as OriginalMistralAttention, -) -from transformers.models.mistral.modeling_mistral import ( - MistralDecoderLayer as OriginalMistralDecoderLayer, -) -from transformers.models.mistral.modeling_mistral import ( - apply_rotary_pos_emb, - repeat_kv, -) -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -def replace_mistral_attn_with_flash_attn( - packed: Optional[bool] = False, -): - transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access - _prepare_decoder_attention_mask - ) - transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( - flashattn_forward - ) - if packed: - transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( - MistralDecoderLayer - ) - transformers.models.mistral.modeling_mistral.MistralModel.forward = ( - mistral_model_forward - ) - - def patch_mistral_cross_entropy(): from flash_attn.losses.cross_entropy import CrossEntropyLoss @@ -57,604 +18,3 @@ def patch_mistral_cross_entropy(): transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial( CrossEntropyLoss, inplace_backward=True ) - - -@torch.jit.script -def _make_sliding_window_causal_mask( - bsz: int, - tgt_len: int, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: int = 4096, -): - """ - Make causal mask used for sliding window attention - """ - tensor = torch.full( - (tgt_len, tgt_len), - fill_value=1, - device=device, - ) - mask = torch.tril(tensor, diagonal=0) - # make the mask banded to account for sliding 
window - # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1 - mask = torch.triu(mask, diagonal=-sliding_window + 1) - mask = torch.log(mask).to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -# Disable the transformation of the attention mask in LlamaModel as the flash attention -# requires the attention mask to be the same as the key_padding_mask -def _prepare_decoder_attention_mask( - self, - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - sliding_window, -): # pylint: disable=unused-argument - # [bsz, seq_len] - if attention_mask is None or sliding_window is None: - return attention_mask - - # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios. - # Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled. - if input_shape[-1] > 1 and attention_mask.shape[0] == 1: - sliding_window_mask = _make_sliding_window_causal_mask( - bsz=input_shape[0], - tgt_len=input_shape[1], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - attention_mask = attention_mask + sliding_window_mask - else: - LOG.info("skipping sliding window mask, not broadcastable with attention mask") - - return attention_mask - - -def flashattn_forward( - self: OriginalMistralAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, position_ids=position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - use_sliding_windows = ( - getattr(self.config, "sliding_window") is not None - and kv_seq_len > self.config.sliding_window - ) - - if use_sliding_windows: - window_size = (self.config.sliding_window, self.config.sliding_window) - else: - window_size = (-1, -1) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - if ( - hasattr(self.config, "sliding_window") - and kv_seq_len > self.config.sliding_window - ): - slicing_tokens = kv_seq_len - self.config.sliding_window - - past_key = past_key_value[0] - past_value = past_key_value[1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, 
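# The tril/triu combination above yields a banded causal mask; a tiny demo
# with a window of 3 over 5 positions:
import torch

window = 3
mask = torch.tril(torch.ones(5, 5))            # causal lower triangle
mask = torch.triu(mask, diagonal=-window + 1)  # keep at most `window` keys per query
print(torch.log(mask))                         # 0.0 where attended, -inf where masked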
:].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - past_key_value = (past_key, past_value) if use_cache else None - - if past_key_value is not None: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - window_size=window_size, - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - window_size=window_size, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - 
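# The cache slicing above trims the KV cache so cached plus incoming tokens fit
# inside the sliding window; with illustrative sizes (cache of 10, window 4,
# one incoming token):
import torch

sliding_window, q_len = 4, 1
past_key = torch.randn(1, 8, 10, 16)           # (bsz, nheads, cached_len, head_dim)
kv_seq_len = past_key.shape[-2] + q_len
slicing_tokens = kv_seq_len - sliding_window
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
assert past_key.shape[-2] == sliding_window - q_len  # room left for the new token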
dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - def output_pad_fn(output_unpad): - return pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - def output_pad_fn(output_unpad): - return rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - -def mistral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[ # pylint: disable=unused-argument - torch.LongTensor - ] = None, -) -> Union[Tuple, 
BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - transformers.logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
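# Gradient checkpointing, as invoked below, recomputes a layer during backward
# instead of storing its activations; a toy equivalent (recent PyTorch):
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)
y = checkpoint(layer, x, use_reentrant=False)  # activations rebuilt on backward
y.sum().backward()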
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = ( - self._gradient_checkpointing_func( # pylint: disable=protected-access - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - cu_seqlens, - max_seqlen, - ) - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class MistralDecoderLayer(OriginalMistralDecoderLayer): - """ - patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 1d089ba41..beaee57c9 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -531,12 +531,6 @@ class AxolotlInputConfig( "description": "Whether to use flash-attention rms norm implementation - advanced use only" }, ) - flash_attn_fuse_qkv: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Whether to fuse QKV into a single operation" - }, - ) flash_attn_fuse_mlp: bool | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 61eec65d5..e15adf077 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -577,9 +577,7 @@ class LoRAValidationMixin: @model_validator(mode="after") def check_fused_lora(self): - if self.adapter in ["lora", "qlora"] and ( - self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp - ): + if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp: raise ValueError("Fused modules are not supported with LoRA/QLoRA") return self @@ -1184,7 +1182,7 @@ class ComplexValidationMixin: "ReLoRA is not compatible with the one_cycle scheduler" ) - if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp: + if self.flash_attn_fuse_mlp: raise ValueError("Fused modules are not supported with ReLoRA") return self diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index a3fe591ee..f0c4f155f 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -29,7 +29,6 @@ class TestFusedLlama(unittest.TestCase): "base_model": "HuggingFaceTB/SmolLM2-135M", "flash_attention": True, "pad_to_sequence_len": True, - "flash_attn_fuse_qkv": True, "flash_attn_fuse_mlp": True, "sample_packing": True, "sequence_len": 1024,
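# Reduced sketch of the check_fused_lora validator after this change, outside
# its pydantic mixin:
def check_fused_lora(adapter: str | None, flash_attn_fuse_mlp: bool | None) -> None:
    if adapter in ("lora", "qlora") and flash_attn_fuse_mlp:
        raise ValueError("Fused modules are not supported with LoRA/QLoRA")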