fix xformers inference
This commit is contained in:
@@ -7,7 +7,6 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops.fmha
|
import xformers.ops.fmha
|
||||||
from flash_attn.bert_padding import pad_input
|
|
||||||
from transformers.modeling_flash_attention_utils import (
|
from transformers.modeling_flash_attention_utils import (
|
||||||
_upad_input,
|
_upad_input,
|
||||||
)
|
)
|
||||||
@@ -34,14 +33,29 @@ def xformers_attention_forward(
|
|||||||
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
|
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
|
# Get dimensions
|
||||||
|
# query: [batch, heads, seq_len, hidden_dim]
|
||||||
|
batch_size = query.size(0)
|
||||||
|
query_length = query.shape[2]
|
||||||
|
key_length = key.shape[2]
|
||||||
|
|
||||||
|
# Default causal mask
|
||||||
|
attn_bias = xformers.ops.LowerTriangularMask()
|
||||||
|
|
||||||
|
# Check if we have sliding window attention
|
||||||
|
has_sliding_window = sliding_window is not None and sliding_window < query_length
|
||||||
|
|
||||||
|
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
|
||||||
query = query.transpose(1, 2)
|
query = query.transpose(1, 2)
|
||||||
key = key.transpose(1, 2)
|
key = key.transpose(1, 2)
|
||||||
value = value.transpose(1, 2)
|
value = value.transpose(1, 2)
|
||||||
query_length = query.shape[2]
|
|
||||||
batch_size = query.size(0)
|
|
||||||
|
|
||||||
attn_bias = xformers.ops.LowerTriangularMask()
|
# Get GQA parameters
|
||||||
|
num_attention_heads = module.config.num_attention_heads
|
||||||
|
num_key_value_heads = module.config.num_key_value_heads
|
||||||
|
head_dim = query.size(-1)
|
||||||
|
is_gqa = num_attention_heads != num_key_value_heads
|
||||||
|
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
|
||||||
|
|
||||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||||
@@ -64,25 +78,13 @@ def xformers_attention_forward(
|
|||||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||||
value = value.reshape(-1, value.size(-2), value.size(-1))
|
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# Handle GQA
|
||||||
if module.config.num_attention_heads != module.config.num_key_value_heads:
|
if is_gqa:
|
||||||
key = key.repeat_interleave(
|
key = key.repeat_interleave(n_groups, dim=2)
|
||||||
module.config.num_attention_heads // module.config.num_key_value_heads,
|
value = value.repeat_interleave(n_groups, dim=2)
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
value = value.repeat_interleave(
|
|
||||||
module.config.num_attention_heads // module.config.num_key_value_heads,
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = xformers_attention(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
attn_bias=attn_bias,
|
|
||||||
)
|
|
||||||
elif attention_mask is not None:
|
elif attention_mask is not None:
|
||||||
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
query, key, value, _, cu_seq_lens, _ = _upad_input(
|
||||||
query, key, value, attention_mask, query_length
|
query, key, value, attention_mask, query_length
|
||||||
)
|
)
|
||||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||||
@@ -94,31 +96,63 @@ def xformers_attention_forward(
|
|||||||
kv_seqlen=seq_lengths,
|
kv_seqlen=seq_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output_unpad = xformers_attention(
|
# Handle GQA
|
||||||
query,
|
if is_gqa:
|
||||||
key,
|
key = key.repeat_interleave(n_groups, dim=2)
|
||||||
value,
|
value = value.repeat_interleave(n_groups, dim=2)
|
||||||
attn_bias=attn_bias,
|
|
||||||
)
|
|
||||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
|
||||||
else:
|
else:
|
||||||
# pylint: disable=duplicate-code
|
# Handle Group Query Attention (GQA) using view/expand approach from reference
|
||||||
if module.config.num_attention_heads != module.config.num_key_value_heads:
|
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||||
key = key.repeat_interleave(
|
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||||
module.config.num_attention_heads // module.config.num_key_value_heads,
|
key = key.expand(
|
||||||
dim=2,
|
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||||
)
|
)
|
||||||
value = value.repeat_interleave(
|
value = value.expand(
|
||||||
module.config.num_attention_heads // module.config.num_key_value_heads,
|
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||||
dim=2,
|
)
|
||||||
|
|
||||||
|
if module.training:
|
||||||
|
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||||
|
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||||
|
|
||||||
|
if has_sliding_window:
|
||||||
|
query = query.view(
|
||||||
|
1, batch_size * query_length, num_attention_heads, head_dim
|
||||||
|
)
|
||||||
|
key = key.view(
|
||||||
|
1, batch_size * key_length, num_attention_heads, head_dim
|
||||||
|
)
|
||||||
|
value = value.view(
|
||||||
|
1, batch_size * key_length, num_attention_heads, head_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = query.view(
|
||||||
|
batch_size, query_length, num_key_value_heads, n_groups, head_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = xformers_attention(
|
# If we need a sliding window attention
|
||||||
query,
|
if has_sliding_window:
|
||||||
key,
|
query = query.view(
|
||||||
value,
|
1,
|
||||||
attn_bias=attn_bias,
|
batch_size * query_length,
|
||||||
)
|
num_key_value_heads,
|
||||||
|
n_groups,
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
key = key.view(
|
||||||
|
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
value = value.view(
|
||||||
|
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the xformers attention
|
||||||
|
attn_output = xformers_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_bias=attn_bias,
|
||||||
|
)
|
||||||
|
|
||||||
attn_output = attn_output.view(
|
attn_output = attn_output.view(
|
||||||
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
|
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user