From 06edf175ace06ad11872427f812c94cef6158bf6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Fri, 18 Aug 2023 09:54:16 -0700 Subject: [PATCH] standardize attn hijack patches (#381) * split sdp attn into its own patch * sync xformers patch to follow shared format and be diffable * update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests * speed up flash-attn inference * fix patch to check position ids and don't use multipack for evals * copy LlamaModel.forward and LlamaDecoderLayer.forward into monkeypatch * update forwards so we only calculate cu_seqlens once * enable eval dataloader using multipack again * fix the patch to work properly and work with FSDP --------- Co-authored-by: Wing Lian --- .../monkeypatch/llama_attn_hijack_flash.py | 688 ++++++++++++++---- .../monkeypatch/llama_attn_hijack_sdp.py | 140 ++++ .../monkeypatch/llama_attn_hijack_xformers.py | 219 ++---- src/axolotl/utils/models.py | 6 +- 4 files changed, 750 insertions(+), 303 deletions(-) create mode 100644 src/axolotl/monkeypatch/llama_attn_hijack_sdp.py diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 6cdd50934..14056fa54 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -2,142 +2,47 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py -from typing import Optional, Tuple +import warnings +from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import transformers from einops import rearrange from flash_attn.bert_padding import pad_input, unpad_input +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer as OriginalLlamaDecoderLayer, +) +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids try: - from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports + flash_attn_kvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + ) except ImportError: + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func, + ) from flash_attn.flash_attn_interface import ( flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, ) -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids - - -def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel - - attention_mask: [bsz, q_len] - """ - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - query_states = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) +def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access + _prepare_decoder_attention_mask ) - key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - assert past_key_value is None, "past_key_value is not supported" - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - assert not output_attentions, "output_attentions is not supported" - assert not use_cache, "use_cache is not supported" - - # Flash attention codes from - # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py - - # transform the data into the format required by flash attention - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask = attention_mask - - if key_padding_mask is None: - qkv = rearrange(qkv, "b s ... -> (b s) ...") - max_s = q_len - cu_q_lens = torch.arange( - 0, - (bsz + 1) * q_len, - step=q_len, - dtype=torch.int32, - device=qkv.device, + transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward + if packed: + transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer + transformers.models.llama.modeling_llama.LlamaModel.forward = ( + llama_model_forward ) - output = flash_attn_varlen_qkvpacked_func( - qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif attention_mask.shape[0] == 1: - # special handling using sample packing - qkv = rearrange(qkv, "b s ... -> (b s) ...") - cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids) - cu_q_lens = cu_q_lens.squeeze() - - output = flash_attn_varlen_qkvpacked_func( - qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - else: - nheads = qkv.shape[-2] - - # pylint: disable=invalid-name - x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) - x_unpad = rearrange( - x_unpad, - "nnz (three h d) -> nnz three h d", - three=3, - h=nheads, - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - x_unpad, - cu_q_lens, - max_s, - 0.0, - softmax_scale=None, - causal=True, - ) - output = rearrange( - pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), - indices, - bsz, - q_len, - ), - "b s (h d) -> b s h d", - h=nheads, - ) - - return ( - self.o_proj(rearrange(output, "b s h d -> b s (h d)")), - None, - None, - ) # Disable the transformation of the attention mask in LlamaModel as the flash attention @@ -153,8 +58,541 @@ def _prepare_decoder_attention_mask( return attention_mask -def replace_llama_attn_with_flash_attn(): - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access - _prepare_decoder_attention_mask +def flashattn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + if not hasattr(self, "pretraining_tp"): + self.pretraining_tp = 1 + + if self.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids ) - transformers.models.llama.modeling_llama.LlamaAttention.forward = forward + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + # + # flash-attn v2 start + # + + if self.training: + # during training q,k,v always have same seqlen + assert key_states.shape == query_states.shape + is_causal = True + else: + # turn off FA causal mask after first inference autoregressive iteration + # only on first autoregressive step q,k,v have same seqlen + is_causal = past_key_value is not None + + if cu_seqlens is not None and max_seqlen is not None: + # special handling using sample packing + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + qkv = rearrange(qkv, "b s ... -> (b s) ...") + + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal + ) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + elif query_states.shape == key_states.shape: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( + query_states, + key_states, + value_states, + qkvpacked=True, + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask=attention_mask, + query_padding_mask=attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None, + ) + output_unpad = flash_attn_varlen_qkvpacked_func( + qkv_unpad, + cu_seqlens_q, + max_seqlen_q, + 0.0, + softmax_scale=None, + causal=is_causal, + ) + output = output_pad_fn(output_unpad) + else: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if attention_mask is None or attention_mask.all().item(): + output = flash_attn_kvpacked_func( + query_states, + torch.stack([key_states, value_states], 2), + causal=is_causal, + ) + else: + ( # pylint: disable=unbalanced-tuple-unpacking + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + _, + _, + output_pad_fn, + ) = generate_qkv( + query_states, + key_states, + value_states, + kvpacked=True, + key_padding_mask=attention_mask, + query_padding_mask=attention_mask[:, -query_states.size(1) :] + if attention_mask is not None + else None, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + softmax_scale=None, + causal=is_causal, + ) + output = output_pad_fn(output_unpad) + + attn_output = output + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = rearrange(attn_output, "b s h d -> b s (h d)") + + # + # flash-attn v2 end + # + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.pretraining_tp, dim=1 + ) + attn_output = sum( + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp) + ) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + kvpacked=False, + qkvpacked=False, +): # pylint: disable=invalid-name,unnecessary-lambda-assignment + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( + q, query_padding_mask + ) + + output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 + output_unpad, indices_q, batch_size, seqlen_q + ) + + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + max_seqlen_q = seqlen_q + + output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) + + if kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + return ( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + kv, + output_pad_fn, + ) + + return ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + ) + + +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + cu_seqlens = None + max_seqlen = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) + cu_seqlens = cu_seqlens.squeeze() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = ( + self._prepare_decoder_attention_mask( # pylint: disable=protected-access + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + transformers.logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + None, + cu_seqlens, + max_seqlen, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaDecoderLayer(OriginalLlamaDecoderLayer): + """ + patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py b/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py new file mode 100644 index 000000000..2a653ceb6 --- /dev/null +++ b/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py @@ -0,0 +1,140 @@ +""" +Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention +""" + +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import transformers.models.llama.modeling_llama +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + + +def hijack_llama_sdp_attention(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = ( + sdp_attention_forward + ) + + +def sdp_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + if not hasattr(self, "pretraining_tp"): + self.pretraining_tp = 1 + + if self.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + # + # sdp-attn start + # + + with torch.backends.cuda.sdp_kernel(): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=False, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + # + # sdp-attn end + # + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.pretraining_tp, dim=1 + ) + attn_output = sum( + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp) + ) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 752e204f7..c9d517646 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g """ import logging -import math +import warnings from typing import Optional, Tuple import torch import torch.nn.functional as F import transformers.models.llama.modeling_llama -from torch import nn +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv try: import xformers.ops @@ -21,12 +21,6 @@ def hijack_llama_attention(): transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward -def hijack_llama_sdp_attention(): - transformers.models.llama.modeling_llama.LlamaAttention.forward = ( - sdp_attention_forward - ) - - def xformers_forward( self, hidden_states: torch.Tensor, @@ -81,15 +75,15 @@ def xformers_forward( value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - ( - query_states, - key_states, - ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) # [bsz, nh, t, hd] @@ -102,74 +96,50 @@ def xformers_forward( past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads - key_states = transformers.models.llama.modeling_llama.repeat_kv( - key_states, self.num_key_value_groups - ) - value_states = transformers.models.llama.modeling_llama.repeat_kv( - value_states, self.num_key_value_groups - ) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - # We only apply xformers optimizations if we don't need to output the whole attention matrix - if not output_attentions: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) - # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. - # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. - if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: - # input and output should be of form (bsz, q_len, num_heads, head_dim) - attn_output = xformers.ops.memory_efficient_attention( - query_states, key_states, value_states, attn_bias=None - ) - else: - # input and output should be of form (bsz, q_len, num_heads, head_dim) - attn_output = xformers.ops.memory_efficient_attention( - query_states, - key_states, - value_states, - # attn_bias=attention_mask, - attn_bias=xformers.ops.LowerTriangularMask(), - ) - attn_weights = None + # + # xformers-attn start + # + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. + if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) else: - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - # end x-formers vs. not x-formers if-else block + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + # attn_bias=attention_mask, + attn_bias=xformers.ops.LowerTriangularMask(), + ) + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # + # xformers-attn end + # + if self.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split( @@ -182,103 +152,4 @@ def xformers_forward( else: attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value - - -def sdp_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - query_states = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - ( - query_states, - key_states, - ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # We only apply sdp attention if we don't need to output the whole attention matrix - if not output_attentions: - with torch.backends.cuda.sdp_kernel(): - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=False, - ) - attn_weights = None - else: - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 0a438ed21..66cf70b64 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -103,7 +103,7 @@ def load_model( ) LOG.info("patching with flash attention") - replace_llama_attn_with_flash_attn() + replace_llama_attn_with_flash_attn(packed=cfg.sample_packing) elif cfg.is_llama_derived_model and cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, @@ -112,9 +112,7 @@ def load_model( LOG.info("patching with xformers attention") hijack_llama_attention() elif cfg.is_llama_derived_model and cfg.sdp_attention: - from axolotl.monkeypatch.llama_attn_hijack_xformers import ( - hijack_llama_sdp_attention, - ) + from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention LOG.info("patching with sdp attention") hijack_llama_sdp_attention()