speed up flash-attn inference
This commit is contained in:
@@ -16,6 +16,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
flash_attn_varlen_kvpacked_func,
|
flash_attn_varlen_kvpacked_func,
|
||||||
flash_attn_varlen_qkvpacked_func,
|
flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
@@ -146,7 +147,7 @@ def flashattn_forward(
|
|||||||
else:
|
else:
|
||||||
# turn off FA causal mask after first inference autoregressive iteration
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = key_states.shape == query_states.shape
|
is_causal = past_key_value is not None
|
||||||
|
|
||||||
if self.training and attention_mask.shape[0] == 1:
|
if self.training and attention_mask.shape[0] == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
@@ -163,14 +164,20 @@ def flashattn_forward(
|
|||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
query_states.transpose(1, 2),
|
query_states,
|
||||||
key_states.transpose(1, 2),
|
key_states,
|
||||||
value_states.transpose(1, 2),
|
value_states,
|
||||||
qkvpacked=True,
|
qkvpacked=True,
|
||||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
# the attention_mask should be the same as the key_padding_mask
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
key_padding_mask=attention_mask,
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv_unpad,
|
qkv_unpad,
|
||||||
@@ -182,35 +189,48 @@ def flashattn_forward(
|
|||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
else:
|
else:
|
||||||
( # pylint: disable=unbalanced-tuple-unpacking
|
query_states = query_states.transpose(1, 2)
|
||||||
q_unpad,
|
key_states = key_states.transpose(1, 2)
|
||||||
kv_unpad,
|
value_states = value_states.transpose(1, 2)
|
||||||
cu_seqlens_q,
|
if attention_mask is None or attention_mask.all().item():
|
||||||
cu_seqlens_k,
|
output = flash_attn_kvpacked_func(
|
||||||
max_seqlen_q,
|
query_states,
|
||||||
max_seqlen_k,
|
torch.stack([key_states, value_states], 2),
|
||||||
_,
|
causal=is_causal,
|
||||||
_,
|
)
|
||||||
output_pad_fn,
|
else:
|
||||||
) = generate_qkv(
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
query_states.transpose(1, 2),
|
q_unpad,
|
||||||
key_states.transpose(1, 2),
|
kv_unpad,
|
||||||
value_states.transpose(1, 2),
|
cu_seqlens_q,
|
||||||
kvpacked=True,
|
cu_seqlens_k,
|
||||||
key_padding_mask=attention_mask,
|
max_seqlen_q,
|
||||||
)
|
max_seqlen_k,
|
||||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
_,
|
||||||
q_unpad,
|
_,
|
||||||
kv_unpad,
|
output_pad_fn,
|
||||||
cu_seqlens_q,
|
) = generate_qkv(
|
||||||
cu_seqlens_k,
|
query_states,
|
||||||
max_seqlen_q,
|
key_states,
|
||||||
max_seqlen_k,
|
value_states,
|
||||||
0.0,
|
kvpacked=True,
|
||||||
softmax_scale=None,
|
key_padding_mask=attention_mask,
|
||||||
causal=is_causal,
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
)
|
if attention_mask is not None
|
||||||
output = output_pad_fn(output_unpad)
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
attn_output = output
|
attn_output = output
|
||||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
|||||||
Reference in New Issue
Block a user