From 13f7efaf74fcd3c4514277ccb71914c589873f6a Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 13 Aug 2023 18:03:38 +0000
Subject: [PATCH] speed up flash-attn inference

---
 .../monkeypatch/llama_attn_hijack_flash.py   | 86 ++++++++++++-------
 1 file changed, 53 insertions(+), 33 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 3e94a07cb..4c48a76fd 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -16,6 +16,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_kvpacked_func,
         flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
     )
@@ -146,7 +147,7 @@ def flashattn_forward(
     else:
         # turn off FA causal mask after first inference autoregressive iteration
         # only on first autoregressive step q,k,v have same seqlen
-        is_causal = key_states.shape == query_states.shape
+        is_causal = past_key_value is not None
 
     if self.training and attention_mask.shape[0] == 1:
         # special handling using sample packing
@@ -163,14 +164,20 @@ def flashattn_forward(
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
-            query_states.transpose(1, 2),
-            key_states.transpose(1, 2),
-            value_states.transpose(1, 2),
+            query_states,
+            key_states,
+            value_states,
             qkvpacked=True,
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
+            query_padding_mask=attention_mask[:, -query_states.size(1) :]
+            if attention_mask is not None
+            else None,
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
@@ -182,35 +189,48 @@ def flashattn_forward(
         )
         output = output_pad_fn(output_unpad)
     else:
-        (  # pylint: disable=unbalanced-tuple-unpacking
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            _,
-            _,
-            output_pad_fn,
-        ) = generate_qkv(
-            query_states.transpose(1, 2),
-            key_states.transpose(1, 2),
-            value_states.transpose(1, 2),
-            kvpacked=True,
-            key_padding_mask=attention_mask,
-        )
-        output_unpad = flash_attn_varlen_kvpacked_func(
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            0.0,
-            softmax_scale=None,
-            causal=is_causal,
-        )
-        output = output_pad_fn(output_unpad)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if attention_mask is None or attention_mask.all().item():
+            output = flash_attn_kvpacked_func(
+                query_states,
+                torch.stack([key_states, value_states], 2),
+                causal=is_causal,
+            )
+        else:
+            (  # pylint: disable=unbalanced-tuple-unpacking
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                _,
+                _,
+                output_pad_fn,
+            ) = generate_qkv(
+                query_states,
+                key_states,
+                value_states,
+                kvpacked=True,
+                key_padding_mask=attention_mask,
+                query_padding_mask=attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None,
+            )
+            output_unpad = flash_attn_varlen_kvpacked_func(
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+            )
+            output = output_pad_fn(output_unpad)
 
     attn_output = output
     if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):