black formatting

ignore copied file
fix linting
This commit is contained in:
Wing Lian
2023-05-30 23:36:01 -04:00
parent 6cb2310592
commit ad0ea6aaab
2 changed files with 89 additions and 22 deletions

View File

@@ -5,6 +5,9 @@ exclude = venv
[mypy-alpaca_lora_4bit.*] [mypy-alpaca_lora_4bit.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-flash_attn.*] [mypy-flash_attn.*]
ignore_missing_imports = True ignore_missing_imports = True
@@ -31,3 +34,6 @@ ignore_missing_imports = True
[mypy-addict] [mypy-addict]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-xformers.*]
ignore_missing_imports = True

View File

@@ -1,18 +1,18 @@
''' """
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
''' """
import logging import logging
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
from torch import nn
try: try:
import xformers.ops import xformers.ops
except Exception: except ImportError:
logging.error("xformers not found! Please install it before trying to use it.") logging.error("xformers not found! Please install it before trying to use it.")
@@ -22,7 +22,9 @@ def hijack_llama_attention():
def hijack_llama_sdp_attention(): def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
logging.info("Replaced attention with sdp_attention") logging.info("Replaced attention with sdp_attention")
@@ -37,15 +39,32 @@ def xformers_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = (
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) self.q_proj(hidden_states)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) .view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) (
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd] # [bsz, nh, t, hd]
if past_key_value is not None: if past_key_value is not None:
@@ -65,13 +84,22 @@ def xformers_forward(
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim) # input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None) attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else: else:
# input and output should be of form (bsz, q_len, num_heads, head_dim) # input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask()) attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None attn_weights = None
else: else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError( raise ValueError(
@@ -85,10 +113,14 @@ def xformers_forward(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
) )
attn_weights = attn_weights + attention_mask attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32 # upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states) attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -115,15 +147,32 @@ def sdp_attention_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = (
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) self.q_proj(hidden_states)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) .view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) (
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd] # [bsz, nh, t, hd]
if past_key_value is not None: if past_key_value is not None:
@@ -135,10 +184,18 @@ def sdp_attention_forward(
# We only apply sdp attention if we don't need to output the whole attention matrix # We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False) attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
attn_weights = None attn_weights = None
else: else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError( raise ValueError(
@@ -152,10 +209,14 @@ def sdp_attention_forward(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
) )
attn_weights = attn_weights + attention_mask attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32 # upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states) attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):