diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 8fa00f43b..4d3d6b68e 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -9,6 +9,7 @@ from typing import Optional, Tuple import torch import transformers.models.llama.modeling_llama from torch import nn +import torch.nn.functional as F try: import xformers.ops @@ -38,19 +39,39 @@ def xformers_forward( # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = ( - self.q_proj(hidden_states) + query_states .view(bsz, q_len, self.num_heads, self.head_dim) .transpose(1, 2) ) key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) + key_states + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) .transpose(1, 2) ) value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) + value_states + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) .transpose(1, 2) ) @@ -73,6 +94,10 @@ def xformers_forward( past_key_value = (key_states, value_states) if use_cache else None + # repeat k/v heads if n_kv_heads < n_heads + key_states = transformers.models.llama.modeling_llama.repeat_kv(key_states, self.num_key_value_groups) + value_states = transformers.models.llama.modeling_llama.repeat_kv(value_states, self.num_key_value_groups) + # We only apply xformers optimizations if we don't need to output the whole attention matrix if not output_attentions: query_states = query_states.transpose(1, 2) @@ -128,10 +153,16 @@ def xformers_forward( f" {attn_output.size()}" ) - attn_output = attn_output.transpose(1, 2) - + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value