From 10405b9995dc7de86f5c0a582941cef80ea01db5 Mon Sep 17 00:00:00 2001 From: ssmi153 <129111316+ssmi153@users.noreply.github.com> Date: Mon, 7 Aug 2023 03:09:04 +1200 Subject: [PATCH] Update XFormers Attention Monkeypatch to handle Llama-2 70B (GQA) (#339) * Fix XFormers attention for Llama-2 70B (GQA) Updated XFormers MonkeyPatch to handle GQA as used in Llama-2 70B. All the updated code is taken directly from the Transformers library: https://github.com/huggingface/transformers/commit/07360b6c9c9448d619a82798419ed291dfc6ac8f#diff-06392bad3b9e97be9ade60d4ac46f73b6809388f4d507c2ba1384ab872711c51 from their llama_modeling.py file. * Catch configs without pretraining_tp * Whitespace bug fix Command had accidentally been moved out of if-else block. * pre-commit formatting fixes Thanks to @winglian --- .../monkeypatch/llama_attn_hijack_xformers.py | 83 +++++++++++++++---- 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 8fa00f43b..02525b7f5 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -7,6 +7,7 @@ import math from typing import Optional, Tuple import torch +import torch.nn.functional as F import transformers.models.llama.modeling_llama from torch import nn @@ -38,21 +39,48 @@ def xformers_forward( # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() - query_states = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) + if not hasattr(self, "pretraining_tp"): + self.pretraining_tp = 1 + + if self.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -73,6 +101,14 @@ def xformers_forward( past_key_value = (key_states, value_states) if use_cache else None + # repeat k/v heads if n_kv_heads < n_heads + key_states = transformers.models.llama.modeling_llama.repeat_kv( + key_states, self.num_key_value_groups + ) + value_states = transformers.models.llama.modeling_llama.repeat_kv( + value_states, self.num_key_value_groups + ) + # We only apply xformers optimizations if we don't need to output the whole attention matrix if not output_attentions: query_states = query_states.transpose(1, 2) @@ -128,10 +164,23 @@ def xformers_forward( f" {attn_output.size()}" ) - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() + # end x-formers vs. not x-formers if-else block attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.pretraining_tp, dim=1 + ) + attn_output = sum( + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp) + ) + else: + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value