From c45a786039dd8ee023b8ccd2f2af8fb6662dc765 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 13 Aug 2023 15:41:06 +0000
Subject: [PATCH] sync xformers patch to follow shared format and be diffable

---
 .../monkeypatch/llama_attn_hijack_xformers.py | 114 +++++++-----------
 1 file changed, 45 insertions(+), 69 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index 142d78a91..c9d517646 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
 """

 import logging
-import math
+import warnings
 from typing import Optional, Tuple

 import torch
 import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
-from torch import nn
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

 try:
     import xformers.ops
@@ -75,15 +75,15 @@ def xformers_forward(
     value_states = value_states.view(
         bsz, q_len, self.num_key_value_heads, self.head_dim
     ).transpose(1, 2)
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
+
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    (
-        query_states,
-        key_states,
-    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+    query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
     # [bsz, nh, t, hd]
@@ -96,74 +96,50 @@
     past_key_value = (key_states, value_states) if use_cache else None

     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = transformers.models.llama.modeling_llama.repeat_kv(
-        key_states, self.num_key_value_groups
-    )
-    value_states = transformers.models.llama.modeling_llama.repeat_kv(
-        value_states, self.num_key_value_groups
-    )
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)

-    # We only apply xformers optimizations if we don't need to output the whole attention matrix
-    if not output_attentions:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
+    if output_attentions:
+        warnings.warn(
+            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+        )

-        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
-        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
-        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-            # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(
-                query_states, key_states, value_states, attn_bias=None
-            )
-        else:
-            # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(
-                query_states,
-                key_states,
-                value_states,
-                # attn_bias=attention_mask,
-                attn_bias=xformers.ops.LowerTriangularMask(),
-            )
-        attn_weights = None
+    #
+    # xformers-attn start
+    #
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. + if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) else: - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - # end x-formers vs. not x-formers if-else block + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + # attn_bias=attention_mask, + attn_bias=xformers.ops.LowerTriangularMask(), + ) + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # + # xformers-attn end + # + if self.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split( @@ -176,4 +152,4 @@ def xformers_forward( else: attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value
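
A few notes on the resulting attention path, with small standalone sketches; none of the code below is part of the patch itself.

The repeat_kv change above is mostly cosmetic (the fully-qualified transformers.models.llama.modeling_llama.repeat_kv calls collapse onto the helper imported at the top of the file), but the call itself is what makes grouped-query checkpoints work: when num_key_value_heads < num_heads, each k/v head has to be tiled num_key_value_groups times so keys and values line up one-to-one with the query heads. A rough sketch of what the upstream helper does (for orientation only; the patch uses the transformers implementation, not this one):

    import torch

    def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (batch, num_kv_heads, seq_len, head_dim)
        # -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
        batch, num_kv_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        expanded = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, slen, head_dim
        )
        return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

    # e.g. 8 kv heads tiled 4x to match 32 query heads
    kv = torch.randn(1, 8, 16, 128)
    assert repeat_kv_sketch(kv, 4).shape == (1, 32, 16, 128)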
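
The "nasty hack" probe relies on the invariant stated in the comment: the 4-D additive mask transformers builds here is either all zeros (nothing masked) or lower-triangular, with zeros on and below the diagonal and very negative values above it. Element [0, 0, 0, 1] sits above the diagonal, so one comparison tells the two cases apart. A toy illustration of that invariant (values are illustrative; transformers fills masked positions with the dtype's minimum):

    import torch

    min_val = torch.finfo(torch.float16).min
    # additive causal mask: zeros on/below the diagonal, min_val above it
    causal = torch.full((1, 1, 4, 4), min_val, dtype=torch.float16).triu(1)
    no_mask = torch.zeros(1, 1, 4, 4, dtype=torch.float16)

    assert causal[0, 0, 0, 1] != 0   # upper-triangular entry masked -> causal path
    assert no_mask[0, 0, 0, 1] == 0  # all zeros -> attn_bias=None path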
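
Finally, the layout contract: xformers.ops.memory_efficient_attention consumes and produces (bsz, q_len, num_heads, head_dim) tensors, while the surrounding HF code keeps (bsz, num_heads, q_len, head_dim). That is why the patch transposes q/k/v on entry and can reshape straight to (bsz, q_len, hidden_size) afterwards, dropping the transpose(1, 2).contiguous() step the eager path needed. A minimal sketch of the round trip (illustrative sizes; assumes xformers is installed and a CUDA device is available):

    import torch
    import xformers.ops

    bsz, q_len, num_heads, head_dim = 2, 16, 32, 128  # illustrative sizes

    # HF layout: (bsz, num_heads, q_len, head_dim)
    q = torch.randn(bsz, num_heads, q_len, head_dim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # move to the (bsz, q_len, num_heads, head_dim) layout xformers expects
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    # causal branch: LowerTriangularMask() stands in for the additive mask
    # without materializing a (q_len, q_len) bias tensor
    attn_output = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=xformers.ops.LowerTriangularMask()
    )
    assert attn_output.shape == (bsz, q_len, num_heads, head_dim)

    # fold heads back into hidden_size for o_proj, as the patch does
    attn_output = attn_output.reshape(bsz, q_len, num_heads * head_dim)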