From 6cb23105929b62d46f0ab4fdcf4f27e4fe12eaa8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 May 2023 23:34:36 -0400 Subject: [PATCH 1/6] copy xformers attn from ooba since we removed dep on alpaca_lora_4bit --- .../monkeypatch/llama_attn_hijack_xformers.py | 172 ++++++++++++++++++ src/axolotl/utils/models.py | 9 +- 2 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 src/axolotl/monkeypatch/llama_attn_hijack_xformers.py diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py new file mode 100644 index 000000000..a4f34bcd7 --- /dev/null +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -0,0 +1,172 @@ +''' +Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments +''' + +import logging +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import transformers.models.llama.modeling_llama + +try: + import xformers.ops +except Exception: + logging.error("xformers not found! Please install it before trying to use it.") + + +def hijack_llama_attention(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward + logging.info("Replaced attention with xformers_attention") + + +def hijack_llama_sdp_attention(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward + logging.info("Replaced attention with sdp_attention") + + +def xformers_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # We only apply xformers optimizations if we don't need to output the whole attention matrix + if not output_attentions: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. + if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None) + else: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask()) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +def sdp_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # We only apply sdp attention if we don't need to output the whole attention matrix + if not output_attentions: + attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 0737d0f12..df4e50be5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -97,12 +97,19 @@ def load_model( logging.info("patching with flash attention") replace_llama_attn_with_flash_attn() elif is_llama_derived_model and cfg.xformers_attention: - from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import ( + from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, ) logging.info("patching with xformers attention") hijack_llama_attention() + elif is_llama_derived_model and cfg.sdp_attention: + from axolotl.monkeypatch.llama_attn_hijack_xformers import ( + hijack_llama_sdp_attention, + ) + + logging.info("patching with sdp attention") + hijack_llama_sdp_attention() if cfg.bf16: torch_dtype = torch.bfloat16 From ad0ea6aaabc8276270ccb7bd7f264602f4d4e569 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 May 2023 23:36:01 -0400 Subject: [PATCH 2/6] black formatting ignore copied file fix linting --- .mypy.ini | 6 + .../monkeypatch/llama_attn_hijack_xformers.py | 105 ++++++++++++++---- 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 941046ae8..c542178e0 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -5,6 +5,9 @@ exclude = venv [mypy-alpaca_lora_4bit.*] ignore_missing_imports = True +[mypy-axolotl.monkeypatch.*] +ignore_errors = True + [mypy-flash_attn.*] ignore_missing_imports = True @@ -31,3 +34,6 @@ ignore_missing_imports = True [mypy-addict] ignore_missing_imports = True + +[mypy-xformers.*] +ignore_missing_imports = True diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index a4f34bcd7..bb5728ef1 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -1,18 +1,18 @@ -''' +""" Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments -''' +""" import logging import math from typing import Optional, Tuple import torch -import torch.nn as nn import transformers.models.llama.modeling_llama +from torch import nn try: import xformers.ops -except Exception: +except ImportError: logging.error("xformers not found! Please install it before trying to use it.") @@ -22,7 +22,9 @@ def hijack_llama_attention(): def hijack_llama_sdp_attention(): - transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward + transformers.models.llama.modeling_llama.LlamaAttention.forward = ( + sdp_attention_forward + ) logging.info("Replaced attention with sdp_attention") @@ -37,15 +39,32 @@ def xformers_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + ( + query_states, + key_states, + ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) # [bsz, nh, t, hd] if past_key_value is not None: @@ -65,13 +84,22 @@ def xformers_forward( # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: # input and output should be of form (bsz, q_len, num_heads, head_dim) - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) else: # input and output should be of form (bsz, q_len, num_heads, head_dim) - attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask()) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + ) attn_weights = None else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -85,10 +113,14 @@ def xformers_forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -115,15 +147,32 @@ def sdp_attention_forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + ( + query_states, + key_states, + ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) # [bsz, nh, t, hd] if past_key_value is not None: @@ -135,10 +184,18 @@ def sdp_attention_forward( # We only apply sdp attention if we don't need to output the whole attention matrix if not output_attentions: - attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=False, + ) attn_weights = None else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -152,10 +209,14 @@ def sdp_attention_forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): From 2daa6835f00d8e29927598eefdc8dcf70a3605e8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 May 2023 23:59:05 -0400 Subject: [PATCH 3/6] Update src/axolotl/monkeypatch/llama_attn_hijack_xformers.py Co-authored-by: NanoCode012 --- src/axolotl/monkeypatch/llama_attn_hijack_xformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index bb5728ef1..d4b7165aa 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -25,7 +25,6 @@ def hijack_llama_sdp_attention(): transformers.models.llama.modeling_llama.LlamaAttention.forward = ( sdp_attention_forward ) - logging.info("Replaced attention with sdp_attention") def xformers_forward( From 1076bcbbca68ec101bda719f882803e0615d77f3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 00:00:19 -0400 Subject: [PATCH 4/6] Update src/axolotl/monkeypatch/llama_attn_hijack_xformers.py Co-authored-by: NanoCode012 --- src/axolotl/monkeypatch/llama_attn_hijack_xformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index d4b7165aa..ee013bd30 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -18,7 +18,6 @@ except ImportError: def hijack_llama_attention(): transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward - logging.info("Replaced attention with xformers_attention") def hijack_llama_sdp_attention(): From 2675fb756e13f4e6b7184628c4152929b0ff42c2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 00:02:29 -0400 Subject: [PATCH 5/6] update readme for SDP --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index e1391e39b..853681769 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,9 @@ weight_decay: xformers_attention: # whether to use flash attention patch https://github.com/HazyResearch/flash-attention: flash_attention: # require a100 for llama +# whether to use scaled-dot-product attention +# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html +sdp_attention: # resume from a specific checkpoint dir resume_from_checkpoint: From c56818b11978bafac83ed5e4949cd4d9aa0f6326 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 00:06:47 -0400 Subject: [PATCH 6/6] don't worry about dupes --- src/axolotl/flash_attn.py | 1 + src/axolotl/monkeypatch/llama_attn_hijack_xformers.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py index 6df0b8e18..406dd15ad 100644 --- a/src/axolotl/flash_attn.py +++ b/src/axolotl/flash_attn.py @@ -25,6 +25,7 @@ def forward( attention_mask: [bsz, q_len] """ + # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = ( diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index ee013bd30..c6bdafb89 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -35,6 +35,7 @@ def xformers_forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = ( @@ -143,6 +144,7 @@ def sdp_attention_forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = (