From a6fefa8885e801de8f0485ce546b331935570aec Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 30 Nov 2023 22:30:28 +0100 Subject: [PATCH] Initial refactor [untested] --- src/axolotl/monkeypatch/flash_module.py | 385 ++++++++++++++++ src/axolotl/monkeypatch/fused_module.py | 94 ++++ .../monkeypatch/llama_attn_hijack_flash.py | 434 +----------------- .../monkeypatch/mistral_attn_hijack_flash.py | 298 +----------- 4 files changed, 488 insertions(+), 723 deletions(-) create mode 100644 src/axolotl/monkeypatch/flash_module.py create mode 100644 src/axolotl/monkeypatch/fused_module.py diff --git a/src/axolotl/monkeypatch/flash_module.py b/src/axolotl/monkeypatch/flash_module.py new file mode 100644 index 000000000..443523836 --- /dev/null +++ b/src/axolotl/monkeypatch/flash_module.py @@ -0,0 +1,385 @@ +import torch +import logging +import warnings +from einops import rearrange +import torch.nn.functional as F +from typing import Optional, Tuple +from flash_attn.bert_padding import pad_input, unpad_input + +from axolotl.monkeypatch.fused_module import FusedAttention + +try: + from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports + flash_attn_kvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + ) +except ImportError: + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, + ) + +LOG = logging.getLogger("axolotl") + +def flashattn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, + *args, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + if not hasattr(self, "pretraining_tp"): + self.pretraining_tp = 1 + + if self.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + if isinstance(self, FusedAttention): + query_states, key_states, value_states = self.qkv_proj(hidden_states).split( + self.out_features, dim=-1 + ) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, 
self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = self.apply_rotary_fn(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+
+    use_sliding_windows = (
+        getattr(self.config, "sliding_window", None) is not None
+        and kv_seq_len > self.config.sliding_window
+    )
+
+    if use_sliding_windows:
+        window_size = (self.config.sliding_window, self.config.sliding_window)
+    else:
+        window_size = (-1, -1)
+
+    if past_key_value is not None:
+        # Activate slicing cache only if the config has a `sliding_window` attribute
+        if (
+            getattr(self.config, "sliding_window", None) is not None
+            and kv_seq_len > self.config.sliding_window
+        ):
+            slicing_tokens = kv_seq_len - self.config.sliding_window
+
+            past_key = past_key_value[0]
+            past_value = past_key_value[1]
+
+            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+            past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+            if past_key.shape[-2] != self.config.sliding_window - 1:
+                raise ValueError(
+                    f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                    f" {past_key.shape}"
+                )
+
+            past_key_value = (past_key, past_value) if use_cache else None
+
+    if past_key_value is not None:
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = self.repeat_kv_fn(key_states, self.num_key_value_groups)
+    value_states = self.repeat_kv_fn(value_states, self.num_key_value_groups)
+
+    if output_attentions:
+        warnings.warn(
+            "Output attentions is not supported for the patched attention, returning `None` instead."
+        )
+
+    #
+    # flash-attn v2 start
+    #
+
+    if self.training:
+        # during training q,k,v always have same seqlen
+        assert key_states.shape == query_states.shape
+        is_causal = True
+    else:
+        # turn off FA causal mask after first inference autoregressive iteration
+        # only on first autoregressive step q,k,v have same seqlen
+        is_causal = key_states.shape == query_states.shape
+
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
+    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
+        # special handling using sample packing
+        qkv = torch.stack(
+            [query_states, key_states, value_states], dim=2
+        )  # [bsz, nh, 3, q_len, hd]
+        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+
+        output = flash_attn_varlen_qkvpacked_func(
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            dropout_p=dropout_rate,
+            softmax_scale=None,
+            causal=True,
+            window_size=window_size,
+        )
+        output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) + elif query_states.shape == key_states.shape: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( + query_states, + key_states, + value_states, + qkvpacked=True, + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask=attention_mask, + query_padding_mask=attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None, + ) + output_unpad = flash_attn_varlen_qkvpacked_func( + qkv_unpad, + cu_seqlens_q, + max_seqlen_q, + dropout_p=dropout_rate, + softmax_scale=None, + causal=is_causal, + window_size=window_size, + ) + output = output_pad_fn(output_unpad) + else: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if attention_mask is None or attention_mask.all().item(): + output = flash_attn_kvpacked_func( + query_states, + torch.stack([key_states, value_states], 2), + dropout_p=dropout_rate, + causal=is_causal, + window_size=window_size, + ) + else: + ( # pylint: disable=unbalanced-tuple-unpacking + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + _, + _, + output_pad_fn, + ) = generate_qkv( + query_states, + key_states, + value_states, + kvpacked=True, + key_padding_mask=attention_mask, + query_padding_mask=attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None, + ) + if q_unpad.dtype != kv_unpad.dtype: + kv_unpad = kv_unpad.to(q_unpad.dtype) + output_unpad = flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=dropout_rate, + softmax_scale=None, + causal=is_causal, + window_size=window_size, + ) + output = output_pad_fn(output_unpad) + + attn_output = output + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = rearrange(attn_output, "b s h d -> b s (h d)") + + # + # flash-attn v2 end + # + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.pretraining_tp, dim=1 + ) + attn_output = sum( + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp) + ) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + kvpacked=False, + qkvpacked=False, +): # pylint: disable=invalid-name,unnecessary-lambda-assignment + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, 
indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( + q, query_padding_mask + ) + + output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 + output_unpad, indices_q, batch_size, seqlen_q + ) + + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + max_seqlen_q = seqlen_q + + output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) + + if kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + return ( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + kv, + output_pad_fn, + ) + + return ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + ) diff --git a/src/axolotl/monkeypatch/fused_module.py b/src/axolotl/monkeypatch/fused_module.py new file mode 100644 index 000000000..bb94d8015 --- /dev/null +++ b/src/axolotl/monkeypatch/fused_module.py @@ -0,0 +1,94 @@ +import torch +from typing import List +from xformers.ops import SwiGLU +from axolotl.monkeypatch.utils import set_module_name +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaMLP, +) + +# TODO: Generalize to other attention modules +class FusedAttention(LlamaAttention): + """ + Fused QKV Attention layer for incrementally improved training efficiency + """ + + def __init__( + self, + config, + q: torch.nn.Linear, # pylint: disable=invalid-name + k: torch.nn.Linear, # pylint: disable=invalid-name + v: torch.nn.Linear, # pylint: disable=invalid-name + o: torch.nn.Linear, # pylint: disable=invalid-name + ): + super().__init__(config) + self.config = config + self.init_device = next(iter(q.state_dict().values())).device + + # define equivalent fused qkv projection + self.out_features: List[int] = [q.out_features, k.out_features, v.out_features] + self.qkv_proj = torch.nn.Linear( + q.in_features, sum(self.out_features), device=self.init_device, bias=False + ) + self.o_proj = o + + # overwrite initialized weights with pretrained weights + self.qkv_proj.weight.data = torch.cat( + (q.weight.data, k.weight.data, v.weight.data), dim=0 + ) + + def _post_training(self, model, name): + q_proj, k_proj, v_proj = torch.split( + self.qkv_proj.weight.data, self.out_features, dim=0 + ) + + new_attn = LlamaAttention(self.config) + new_attn.q_proj.weight.data = q_proj + new_attn.k_proj.weight.data = k_proj + new_attn.v_proj.weight.data = v_proj + new_attn.o_proj.weight.data = self.o_proj.weight.data + + set_module_name(model, name, new_attn) + + +class FusedMLP(torch.nn.Module): + """ + Fused MLP layer for incrementally improved training efficiency + """ + + def __init__( + self, + config, + gate_proj: torch.nn.Linear, + 
up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = LlamaMLP(self.config)
+        new_mlp.gate_proj.weight.data = w1
+        new_mlp.up_proj.weight.data = w2
+        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index f380c3f2a..92a383b1b 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -3,15 +3,12 @@
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
 import logging
-import warnings
 from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 import transformers
-from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.llama.modeling_llama import (
@@ -19,26 +16,16 @@ from transformers.models.llama.modeling_llama import (
 )
 from transformers.models.llama.modeling_llama import (
     LlamaMLP,
+)
+
+from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from xformers.ops import SwiGLU
 
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
-
-try:
-    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
-        flash_attn_kvpacked_func,
-        flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_qkvpacked_func,
-    )
-except ImportError:
-    from flash_attn.flash_attn_interface import (
-        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
-    )
-    from flash_attn.flash_attn_interface import (
-        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
-    )
+from axolotl.monkeypatch.fused_module import FusedAttention, FusedMLP
+from axolotl.monkeypatch.flash_module import flashattn_forward
 
 LOG = logging.getLogger("axolotl")
 
@@ -75,6 +62,8 @@ def replace_llama_attn_with_flash_attn(
         _prepare_decoder_attention_mask
    )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    transformers.models.llama.modeling_llama.LlamaAttention.apply_rotary_fn = staticmethod(apply_rotary_pos_emb)
+    transformers.models.llama.modeling_llama.LlamaAttention.repeat_kv_fn = staticmethod(repeat_kv)
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -114,91 +103,6 @@ def replace_llama_attn_with_flash_attn(
     )
 
 
-class FusedAttention(LlamaAttention):
-    """
-    Fused QKV Attention layer for incrementally improved training efficiency
-    """
-
-    def 
__init__( - self, - config, - q: torch.nn.Linear, # pylint: disable=invalid-name - k: torch.nn.Linear, # pylint: disable=invalid-name - v: torch.nn.Linear, # pylint: disable=invalid-name - o: torch.nn.Linear, # pylint: disable=invalid-name - ): - super().__init__(config) - self.config = config - self.init_device = next(iter(q.state_dict().values())).device - - # define equivalent fused qkv projection - self.out_features: List[int] = [q.out_features, k.out_features, v.out_features] - self.qkv_proj = torch.nn.Linear( - q.in_features, sum(self.out_features), device=self.init_device, bias=False - ) - self.o_proj = o - - # overwrite initialized weights with pretrained weights - self.qkv_proj.weight.data = torch.cat( - (q.weight.data, k.weight.data, v.weight.data), dim=0 - ) - - def _post_training(self, model, name): - q_proj, k_proj, v_proj = torch.split( - self.qkv_proj.weight.data, self.out_features, dim=0 - ) - - new_attn = LlamaAttention(self.config) - new_attn.q_proj.weight.data = q_proj - new_attn.k_proj.weight.data = k_proj - new_attn.v_proj.weight.data = v_proj - new_attn.o_proj.weight.data = self.o_proj.weight.data - - set_module_name(model, name, new_attn) - - -class FusedMLP(torch.nn.Module): - """ - Fused MLP layer for incrementally improved training efficiency - """ - - def __init__( - self, - config, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - ): - super().__init__() - self.config = config - self.swiglu = SwiGLU( - in_features=config.hidden_size, - hidden_features=config.intermediate_size, - bias=False, - _pack_weights=True, - ) - # overwrite initialized weights with pretrained weights - self.swiglu.w12.weight.data = torch.cat( - (gate_proj.weight.data, up_proj.weight.data), dim=0 - ) - self.swiglu.w3.weight.data = down_proj.weight.data - - def _post_training(self, model, name): - w1, w2 = torch.split( # pylint: disable=invalid-name - self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0 - ) - - # Assign the split weights back to the original layers - new_mlp = LlamaMLP(self.config) - new_mlp.gate_proj.weight.data = w1 - new_mlp.up_proj.weight.data = w2 - new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data - - set_module_name(model, name, new_mlp) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return self.swiglu(x) - # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask @@ -213,330 +117,6 @@ def _prepare_decoder_attention_mask( return attention_mask -def flashattn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel - - attention_mask: [bsz, q_len] - """ - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - 
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - if isinstance(self, FusedAttention): - query_states, key_states, value_states = self.qkv_proj(hidden_states).split( - self.out_features, dim=-1 - ) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." - ) - - # - # flash-attn v2 start - # - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - # - # flash-attn v2 end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) - else: - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - 
output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
-            output_unpad, indices_q, batch_size, seqlen_q
-        )
-
-    else:
-        q_unpad = rearrange(q, "b s h d -> (b s) h d")
-        cu_seqlens_q = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_q,
-            step=seqlen_q,
-            dtype=torch.int32,
-            device=q_unpad.device,
-        )
-        max_seqlen_q = seqlen_q
-
-        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
-            output_unpad, "(b s) h d -> b s h d", b=batch_size
-        )
-
-    if key_padding_mask is not None:
-        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
-    else:
-        k_unpad = rearrange(k, "b s h d -> (b s) h d")
-        v_unpad = rearrange(v, "b s h d -> (b s) h d")
-        cu_seqlens_k = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_k,
-            step=seqlen_k,
-            dtype=torch.int32,
-            device=k_unpad.device,
-        )
-        max_seqlen_k = seqlen_k
-
-    if qkvpacked:
-        assert nheads == nheads_k
-        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
-        qkv = torch.stack([q, k, v], dim=2)
-        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
-
-    if kvpacked:
-        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
-        kv = torch.stack([k, v], dim=2)
-        return (
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            q,
-            kv,
-            output_pad_fn,
-        )
-
-    return (
-        q_unpad,
-        k_unpad,
-        v_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        q,
-        k,
-        v,
-        output_pad_fn,
-    )
-
 
 def llama_model_forward(
     self,
diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index e31864b83..1bf9851d7 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -36,6 +36,8 @@ def replace_mistral_attn_with_flash_attn(
     transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
         flashattn_forward
     )
+    transformers.models.mistral.modeling_mistral.MistralAttention.apply_rotary_fn = staticmethod(apply_rotary_pos_emb)
+    transformers.models.mistral.modeling_mistral.MistralAttention.repeat_kv_fn = staticmethod(repeat_kv)
     if packed:
         transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
             MistralDecoderLayer
@@ -115,302 +117,6 @@ def _prepare_decoder_attention_mask(
     return attention_mask
 
 
-def flashattn_forward(
-    self: OriginalMistralAttention,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(
-        bsz, q_len, self.num_heads, self.head_dim
-    ).transpose(1, 2)
-    key_states = key_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-    value_states = value_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin, 
position_ids - ) - - use_sliding_windows = ( - hasattr(self.config, "sliding_window") is not None - and kv_seq_len > self.config.sliding_window - ) - - if use_sliding_windows: - window_size = (self.config.sliding_window, self.config.sliding_window) - else: - window_size = (-1, -1) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - if ( - hasattr(self.config, "sliding_window") - and kv_seq_len > self.config.sliding_window - ): - slicing_tokens = kv_seq_len - self.config.sliding_window - - past_key = past_key_value[0] - past_value = past_key_value[1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - past_key_value = (past_key, past_value) if use_cache else None - - if past_key_value is not None: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - window_size=window_size, - ) - output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - window_size=window_size, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - 
(batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - def mistral_model_forward( self, input_ids: torch.LongTensor = None,