fix xformers inference

This commit is contained in:
Wing Lian
2025-05-05 06:20:06 -04:00
parent 4f478083e7
commit 2e74e1d289

View File

@@ -7,7 +7,6 @@ from typing import Optional
import torch import torch
import xformers import xformers
import xformers.ops.fmha import xformers.ops.fmha
from flash_attn.bert_padding import pad_input
from transformers.modeling_flash_attention_utils import ( from transformers.modeling_flash_attention_utils import (
_upad_input, _upad_input,
) )
@@ -34,14 +33,29 @@ def xformers_attention_forward(
max_length_k: Optional[int] = None, # pylint: disable=unused-argument max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2) query = query.transpose(1, 2)
key = key.transpose(1, 2) key = key.transpose(1, 2)
value = value.transpose(1, 2) value = value.transpose(1, 2)
query_length = query.shape[2]
batch_size = query.size(0)
attn_bias = xformers.ops.LowerTriangularMask() # Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
@@ -64,25 +78,13 @@ def xformers_attention_forward(
key = key.reshape(-1, key.size(-2), key.size(-1)) key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1)) value = value.reshape(-1, value.size(-2), value.size(-1))
# pylint: disable=duplicate-code # Handle GQA
if module.config.num_attention_heads != module.config.num_key_value_heads: if is_gqa:
key = key.repeat_interleave( key = key.repeat_interleave(n_groups, dim=2)
module.config.num_attention_heads // module.config.num_key_value_heads, value = value.repeat_interleave(n_groups, dim=2)
dim=2,
)
value = value.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
elif attention_mask is not None: elif attention_mask is not None:
query, key, value, indices_q, cu_seq_lens, _ = _upad_input( query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length query, key, value, attention_mask, query_length
) )
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@@ -94,31 +96,63 @@ def xformers_attention_forward(
kv_seqlen=seq_lengths, kv_seqlen=seq_lengths,
) )
attn_output_unpad = xformers_attention( # Handle GQA
query, if is_gqa:
key, key = key.repeat_interleave(n_groups, dim=2)
value, value = value.repeat_interleave(n_groups, dim=2)
attn_bias=attn_bias,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
# pylint: disable=duplicate-code # Handle Group Query Attention (GQA) using view/expand approach from reference
if module.config.num_attention_heads != module.config.num_key_value_heads: key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.repeat_interleave( value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
module.config.num_attention_heads // module.config.num_key_value_heads, key = key.expand(
dim=2, batch_size, key_length, num_key_value_heads, n_groups, head_dim
) )
value = value.repeat_interleave( value = value.expand(
module.config.num_attention_heads // module.config.num_key_value_heads, batch_size, key_length, num_key_value_heads, n_groups, head_dim
dim=2, )
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
) )
attn_output = xformers_attention( # If we need a sliding window attention
query, if has_sliding_window:
key, query = query.view(
value, 1,
attn_bias=attn_bias, batch_size * query_length,
) num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view( attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1) batch_size, -1, attn_output.size(-2), attn_output.size(-1)