Compare commits

...

2 Commits

Author       SHA1        Message                                              Date
NanoCode012  7888a35118  chore: remove unused log                             2025-03-31 16:20:15 +07:00
NanoCode012  873385b7d5  feat: update xformers for new attention interface    2025-03-31 16:15:55 +07:00


@@ -1,153 +1,113 @@
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
Hijack the LlamaAttention forward method to use xformers if available.
Updated for transformers v4.50.0.
"""
import logging
import warnings
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv
try:
import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
XFORMERS_AVAILABLE = False
def xformers_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs, # pylint: disable=unused-argument
):
"""
Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
Args:
module: The LlamaAttention module
query: Query states of shape [batch, num_heads, seq_len, head_dim]
key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
attention_mask: Attention mask
scaling: Scaling factor for attention scores
dropout: Dropout probability
Returns:
attn_output: Output of xformers memory-efficient attention
attn_weights: None
"""
# First, handle grouped-query attention (GQA)
# We need to repeat key and value states to match the number of query heads
num_key_value_groups = getattr(module, "num_key_value_groups", 1)
key = repeat_kv(key, num_key_value_groups)
value = repeat_kv(value, num_key_value_groups)
# xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Determine if we need a causal mask
is_causal = getattr(module, "is_causal", True)
# Set up the attention bias for xformers
if is_causal:
# Use xformers built-in causal mask
attn_bias = xformers.ops.LowerTriangularMask()
elif attention_mask is not None:
# For non-causal attention with a mask, we'd need to convert the mask
# This is a simplification - you might need to adapt based on your mask format
attn_bias = attention_mask
else:
# No mask needed
attn_bias = None
# Apply xformers memory-efficient attention
attn_output = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attn_bias,
p=dropout if module.training else 0.0,
scale=scaling,
)
# Reshape back to [batch, seq_len, hidden_size]
attn_output = attn_output.transpose(1, 2)
return attn_output, None # Return None for attn_weights to match interface
def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# xformers-attn start
#
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
"""
Patch the LlamaAttention forward method to use xformers if available.
"""
if not XFORMERS_AVAILABLE:
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
"xformers not available. Please install it following axolotl's requirements."
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# xformers-attn end
#
import transformers.models.llama.modeling_llama as llama_modeling
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
# Add xformers to the available attention implementations
llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward
return attn_output, None, past_key_value
# Create a wrapper for the original LlamaAttention forward method
original_forward = llama_modeling.LlamaAttention.forward
def patched_forward(self, *args, **kwargs):
# Set the attention implementation to xformers
# pylint: disable=protected-access
self.config._attn_implementation = "xformers"
return original_forward(self, *args, **kwargs)
# Apply the patch
llama_modeling.LlamaAttention.forward = patched_forward
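
As a usage sketch only (not part of this diff): after hijack_llama_attention() registers the "xformers" implementation and wraps LlamaAttention.forward, a Llama model loaded afterwards should dispatch attention through xformers.ops.memory_efficient_attention. The import path and model id below are illustrative assumptions.

# Hypothetical usage (assumption, not from the diff): apply the monkeypatch before
# instantiating a Llama model so attention dispatches to the "xformers" function.
import torch
from transformers import AutoModelForCausalLM

# Assumed import path for the patched module shown above.
from axolotl.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention

hijack_llama_attention()  # registers "xformers" in ALL_ATTENTION_FUNCTIONS and wraps forward

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder model id
    torch_dtype=torch.bfloat16,
)
# Subsequent forward passes go through the patched LlamaAttention.forward, which sets
# config._attn_implementation = "xformers" before calling the original forward.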