speed up flash-attn inference

This commit is contained in:
Aman Karmani
2023-08-13 18:03:38 +00:00
committed by Wing Lian
parent d773384f74
commit 13f7efaf74

View File

@@ -16,6 +16,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try: try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func, flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
@@ -146,7 +147,7 @@ def flashattn_forward(
else: else:
# turn off FA causal mask after first inference autoregressive iteration # turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape is_causal = past_key_value is not None
if self.training and attention_mask.shape[0] == 1: if self.training and attention_mask.shape[0] == 1:
# special handling using sample packing # special handling using sample packing
@@ -163,14 +164,20 @@ def flashattn_forward(
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states.transpose(1, 2), query_states,
key_states.transpose(1, 2), key_states,
value_states.transpose(1, 2), value_states,
qkvpacked=True, qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel # We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask # the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
) )
output_unpad = flash_attn_varlen_qkvpacked_func( output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad, qkv_unpad,
@@ -182,35 +189,48 @@ def flashattn_forward(
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
else: else:
( # pylint: disable=unbalanced-tuple-unpacking query_states = query_states.transpose(1, 2)
q_unpad, key_states = key_states.transpose(1, 2)
kv_unpad, value_states = value_states.transpose(1, 2)
cu_seqlens_q, if attention_mask is None or attention_mask.all().item():
cu_seqlens_k, output = flash_attn_kvpacked_func(
max_seqlen_q, query_states,
max_seqlen_k, torch.stack([key_states, value_states], 2),
_, causal=is_causal,
_, )
output_pad_fn, else:
) = generate_qkv( ( # pylint: disable=unbalanced-tuple-unpacking
query_states.transpose(1, 2), q_unpad,
key_states.transpose(1, 2), kv_unpad,
value_states.transpose(1, 2), cu_seqlens_q,
kvpacked=True, cu_seqlens_k,
key_padding_mask=attention_mask, max_seqlen_q,
) max_seqlen_k,
output_unpad = flash_attn_varlen_kvpacked_func( _,
q_unpad, _,
kv_unpad, output_pad_fn,
cu_seqlens_q, ) = generate_qkv(
cu_seqlens_k, query_states,
max_seqlen_q, key_states,
max_seqlen_k, value_states,
0.0, kvpacked=True,
softmax_scale=None, key_padding_mask=attention_mask,
causal=is_causal, query_padding_mask=attention_mask[:, -query_states.size(1) :]
) if attention_mask is not None
output = output_pad_fn(output_unpad) else None,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):