fix xformers inference

Wing Lian
2025-05-05 06:20:06 -04:00
parent 4f478083e7
commit 2e74e1d289

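The diff below reworks the xformers attention forward path used for inference. For orientation, here is a minimal, illustrative sketch (not part of the patch) of the core call the function is built around: xformers' memory-efficient attention with a lower-triangular causal bias, with grouped-query KV heads repeated to match the query heads. The shapes, sizes, and CUDA/fp16 setup are assumptions for the example only.

import torch
import xformers.ops as xops

# Illustrative shapes: [batch, seq_len, num_heads, head_dim]
batch, seq_len, n_q_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
query = torch.randn(batch, seq_len, n_q_heads, head_dim, device="cuda", dtype=torch.float16)
key = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.float16)
value = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.float16)

# Repeat KV heads so they match the number of query heads (simple GQA handling)
n_groups = n_q_heads // n_kv_heads
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)

# Causal (lower-triangular) attention over the full sequence
attn_output = xops.memory_efficient_attention(
    query, key, value, attn_bias=xops.LowerTriangularMask()
)
print(attn_output.shape)  # torch.Size([2, 16, 8, 64])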
@@ -7,7 +7,6 @@ from typing import Optional
import torch
import xformers
import xformers.ops.fmha
from flash_attn.bert_padding import pad_input
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
@@ -34,14 +33,29 @@ def xformers_attention_forward(
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_length = query.shape[2]
batch_size = query.size(0)
attn_bias = xformers.ops.LowerTriangularMask()
# Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided, check whether any example contains more than one sequence: if the tensor is
# strictly increasing we probably have a single sequence, otherwise the batch is packed. Also check that we are in the pre-fill/training stage.
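A note on the packed-sequence check described in the comment above: with position_ids of shape [batch, seq_len], positions restart from 0 at each packed-sequence boundary, so a per-row monotonicity test separates a single sequence from a packed batch. A hypothetical helper (not part of the patch) sketching that check:

import torch

def batch_contains_packed_sequences(position_ids: torch.Tensor) -> bool:
    # position_ids: [batch, seq_len]. In a packed batch the positions reset to 0
    # where a new sequence starts, so at least one step is non-increasing.
    steps = position_ids[:, 1:] - position_ids[:, :-1]
    return bool((steps <= 0).any())

# Two sequences of lengths 3 and 2 packed into one row of length 5
assert batch_contains_packed_sequences(torch.tensor([[0, 1, 2, 0, 1]]))
assert not batch_contains_packed_sequences(torch.tensor([[0, 1, 2, 3, 4]]))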
@@ -64,25 +78,13 @@ def xformers_attention_forward(
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# pylint: disable=duplicate-code
if module.config.num_attention_heads != module.config.num_key_value_heads:
key = key.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
value = value.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
elif attention_mask is not None:
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
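In the context elided between this hunk and the next, the function derives per-sequence lengths from these cumulative lengths and builds the attention bias for the unpadded tensors; only the trailing kv_seqlen=seq_lengths argument of that constructor is visible at the top of the next hunk. For reference, a sketch (with made-up lengths, not the patch's exact code) of building a block-diagonal causal bias from per-sequence lengths with xformers:

import torch
from xformers.ops import fmha

# Hypothetical cumulative lengths, e.g. as returned via _upad_input: [0, 3, 8] -> lengths [3, 5]
cu_seqlens = torch.tensor([0, 3, 8])
seq_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

# Tokens attend causally within their own sequence only
attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
    seq_lengths, kv_seqlen=seq_lengths
)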
@@ -94,31 +96,63 @@ def xformers_attention_forward(
kv_seqlen=seq_lengths,
)
attn_output_unpad = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
else:
# pylint: disable=duplicate-code
if module.config.num_attention_heads != module.config.num_key_value_heads:
key = key.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
value = value.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
# Handle Group Query Attention (GQA) using view/expand approach from reference
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
value = value.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
)
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
# If we need a sliding window attention
if has_sliding_window:
query = query.view(
1,
batch_size * query_length,
num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)