diff --git a/.mypy.ini b/.mypy.ini
index 941046ae8..c542178e0 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -5,6 +5,9 @@ exclude = venv
 [mypy-alpaca_lora_4bit.*]
 ignore_missing_imports = True
 
+[mypy-axolotl.monkeypatch.*]
+ignore_errors = True
+
 [mypy-flash_attn.*]
 ignore_missing_imports = True
 
@@ -31,3 +34,6 @@ ignore_missing_imports = True
 
 [mypy-addict]
 ignore_missing_imports = True
+
+[mypy-xformers.*]
+ignore_missing_imports = True
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index a4f34bcd7..bb5728ef1 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -1,18 +1,18 @@
-'''
+"""
 Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
-'''
+"""
 
 import logging
 import math
 from typing import Optional, Tuple
 
 import torch
-import torch.nn as nn
 import transformers.models.llama.modeling_llama
+from torch import nn
 
 try:
     import xformers.ops
-except Exception:
+except ImportError:
     logging.error("xformers not found! Please install it before trying to use it.")
 
 
@@ -22,7 +22,9 @@ def hijack_llama_attention():
 
 
 def hijack_llama_sdp_attention():
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+        sdp_attention_forward
+    )
     logging.info("Replaced attention with sdp_attention")
 
 
@@ -37,15 +39,32 @@ def xformers_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    (
+        query_states,
+        key_states,
+    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
 
     # [bsz, nh, t, hd]
     if past_key_value is not None:
@@ -65,13 +84,22 @@ def xformers_forward(
         # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
         if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+            attn_output = xformers.ops.memory_efficient_attention(
+                query_states, key_states, value_states, attn_bias=None
+            )
         else:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+            attn_output = xformers.ops.memory_efficient_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_bias=xformers.ops.LowerTriangularMask(),
+            )
         attn_weights = None
     else:
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -85,10 +113,14 @@ def xformers_forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -115,15 +147,32 @@ def sdp_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    (
+        query_states,
+        key_states,
+    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
 
     # [bsz, nh, t, hd]
     if past_key_value is not None:
@@ -135,10 +184,18 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            is_causal=False,
+        )
         attn_weights = None
     else:
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -152,10 +209,14 @@ def sdp_attention_forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
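
Usage sketch (not part of the diff): the two entry points touched above, hijack_llama_attention() and hijack_llama_sdp_attention(), patch LlamaAttention.forward module-wide, so they have to run before the model executes a forward pass. The snippet below is a minimal illustration assuming the standard Hugging Face transformers loading flow; the model id is a placeholder, not something this PR prescribes.

# Minimal sketch: apply the monkeypatch before running a LLaMA model so the
# patched LlamaAttention.forward (xformers_forward) is the one that executes.
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.monkeypatch.llama_attn_hijack_xformers import (
    hijack_llama_attention,  # xformers memory-efficient attention
    # hijack_llama_sdp_attention,  # alternative: torch scaled_dot_product_attention
)

hijack_llama_attention()  # replaces LlamaAttention.forward for all LLaMA modules

model_name = "huggyllama/llama-7b"  # placeholder model id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)  # forward now uses xformers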