From 873385b7d59dff5950f9289fd9f0b80cb092c60f Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 31 Mar 2025 16:15:55 +0700
Subject: [PATCH] feat: update xformers for new attention interface

---
 .../monkeypatch/llama_attn_hijack_xformers.py | 233 ++++++++----------
 1 file changed, 98 insertions(+), 135 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index 0c1a4e822..dd4df168d 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -1,153 +1,116 @@
 """
-Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
+Hijack the LlamaAttention forward method to use xformers if available.
+
+Updated for transformers v4.50.0.
 """
 
 import logging
-import warnings
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
-import torch.nn.functional as F
-import transformers.models.llama.modeling_llama
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from torch import nn
+from transformers.models.llama.modeling_llama import repeat_kv
 
 try:
     import xformers.ops
+
+    XFORMERS_AVAILABLE = True
 except ImportError:
-    logging.error("xformers not found! Please install it before trying to use it.")
+    XFORMERS_AVAILABLE = False
+
+
+def xformers_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,  # pylint: disable=unused-argument
+):
+    """
+    Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
+
+    Args:
+        module: The LlamaAttention module
+        query: Query states of shape [batch, num_heads, seq_len, head_dim]
+        key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
+        value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
+        attention_mask: Attention mask
+        scaling: Scaling factor for attention scores
+        dropout: Dropout probability
+
+    Returns:
+        attn_output: Output of xformers memory-efficient attention
+        attn_weights: None
+    """
+    # First, handle grouped-query attention (GQA)
+    # We need to repeat key and value states to match the number of query heads
+    num_key_value_groups = getattr(module, "num_key_value_groups", 1)
+    key = repeat_kv(key, num_key_value_groups)
+    value = repeat_kv(value, num_key_value_groups)
+
+    # xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    # Determine if we need a causal mask
+    is_causal = getattr(module, "is_causal", True)
+
+    # Set up the attention bias for xformers
+    if is_causal:
+        # Use xformers' built-in causal mask
+        attn_bias = xformers.ops.LowerTriangularMask()
+    elif attention_mask is not None:
+        # For non-causal attention with a mask, the mask would need converting;
+        # this is a simplification that may need adapting to the actual mask format
+        attn_bias = attention_mask
+    else:
+        # No mask needed
+        attn_bias = None
+
+    # Apply xformers memory-efficient attention
+    attn_output = xformers.ops.memory_efficient_attention(
+        query,
+        key,
+        value,
+        attn_bias=attn_bias,
+        p=dropout if module.training else 0.0,
+        scale=scaling,
+    )
+
+    # xformers already returns [batch, seq_len, num_heads, head_dim], which is
+    # the layout LlamaAttention.forward expects, so no transpose is needed
+
+    return attn_output, None  # Return None for attn_weights to match interface
 
 
 def hijack_llama_attention():
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
-
-
-def xformers_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
-    **kwargs,  # pylint: disable=unused-argument
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    # pylint: disable=duplicate-code
-    bsz, q_len, _ = hidden_states.size()
-
-    if not hasattr(self, "pretraining_tp"):
-        self.pretraining_tp = 1
-
-    if self.pretraining_tp > 1:
-        key_value_slicing = (
-            self.num_key_value_heads * self.head_dim
-        ) // self.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(
-        bsz, q_len, self.num_heads, self.head_dim
-    ).transpose(1, 2)
-    key_states = key_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-    value_states = value_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-    # [bsz, q_len, nh, hd]
-    # [bsz, nh, q_len, hd]
-
-    cos, sin = self.rotary_emb(value_states)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin, position_ids
-    )
-    # [bsz, nh, t, hd]
-
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-    past_key_value = (key_states, value_states) if use_cache else None
-
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    if output_attentions:
-        warnings.warn(
-            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
-        )
-
-    #
-    # xformers-attn start
-    #
-
-    query_states = query_states.transpose(1, 2)
-    key_states = key_states.transpose(1, 2)
-    value_states = value_states.transpose(1, 2)
-
-    # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
-    # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
-    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-        # input and output should be of form (bsz, q_len, num_heads, head_dim)
-        attn_output = xformers.ops.memory_efficient_attention(
-            query_states, key_states, value_states, attn_bias=None
-        )
-    else:
-        # input and output should be of form (bsz, q_len, num_heads, head_dim)
-        attn_output = xformers.ops.memory_efficient_attention(
-            query_states,
-            key_states,
-            value_states,
-            # attn_bias=attention_mask,
-            attn_bias=xformers.ops.LowerTriangularMask(),
-        )
-
-    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+    """
+    Patch the LlamaAttention forward method to use xformers if available.
+    """
+    if not XFORMERS_AVAILABLE:
         raise ValueError(
-            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
-            f" {attn_output.size()}"
+            "xformers not available. Please install it following axolotl's requirements."
         )
 
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    #
-    # xformers-attn end
-    #
+    import transformers.models.llama.modeling_llama as llama_modeling
 
-    if self.pretraining_tp > 1:
-        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
-        o_proj_slices = self.o_proj.weight.split(
-            self.hidden_size // self.pretraining_tp, dim=1
-        )
-        attn_output = sum(
-            F.linear(attn_output[i], o_proj_slices[i])
-            for i in range(self.pretraining_tp)
-        )
-    else:
-        attn_output = self.o_proj(attn_output)
+    # Add xformers to the available attention implementations
+    llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward
 
-    return attn_output, None, past_key_value
+    # Create a wrapper for the original LlamaAttention forward method
+    original_forward = llama_modeling.LlamaAttention.forward
+
+    def patched_forward(self, *args, **kwargs):
+        # Set the attention implementation to xformers
+        # pylint: disable=protected-access
+        self.config._attn_implementation = "xformers"
+        return original_forward(self, *args, **kwargs)
+
+    # Apply the patch
+    llama_modeling.LlamaAttention.forward = patched_forward
+
+    logging.info("Successfully patched LlamaAttention with xformers")
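
Usage sketch (not part of the patch): a minimal example of applying the hijack
before running a Llama model. This is an illustration under assumptions, not
axolotl's documented API; the checkpoint name is only a placeholder.

    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.llama_attn_hijack_xformers import (
        hijack_llama_attention,
    )

    # Register the "xformers" attention function and wrap LlamaAttention.forward
    hijack_llama_attention()

    # Any Llama checkpoint should work; this name is a placeholder
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

Because patched_forward forces config._attn_implementation to "xformers" on
every call, and forward is looked up on the class at call time, the patch
should also take effect for models instantiated before hijack_llama_attention()
runs, as long as it is called before the first forward pass.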