Use cumulative sequence lengths with variable-length Flash Attention v2 when sample packing is enabled

This commit is contained in:
Wing Lian
2023-08-03 15:50:13 -04:00
parent b8905e2a91
commit b2f7bc7ccd

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
@@ -19,6 +18,34 @@ except ImportError:
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def get_cu_seqlens(attn_mask):
    """Build cumulative sequence lengths from a packed-sequence attention mask.

    Entries equal to 0 are dropped. Among the remaining entries, a new
    sequence starts wherever the value differs from the previous kept value,
    and -- for masks with more than one dimension -- wherever the kept entry
    sits on a different row (all leading dimensions flattened) than the
    previous one, so sequences never span two rows even when they carry the
    same id.

    Args:
        attn_mask: integer tensor of sequence ids, with 0 marking positions
            to ignore. Presumably shaped (seq,) or (batch, seq) -- confirm
            against the caller.

    Returns:
        1-D ``torch.int32`` tensor on ``attn_mask.device``:
        ``[0, end_of_seq_1, end_of_seq_2, ..., number_of_non_zero_entries]``.
        When every entry is 0 the result is ``[0, 0]`` (one empty sequence),
        matching the historical behaviour of this helper.
    """
    device = attn_mask.device
    keep = attn_mask != 0
    # Exclude zeros so ignored positions do not contribute to any length.
    seq_ids = attn_mask[keep]
    total = seq_ids.numel()

    if total == 0:
        # Nothing to attend to: keep reporting a single zero-length sequence.
        return torch.zeros(2, dtype=torch.int32, device=device)

    # The first kept token always opens a sequence; later ones do so when the
    # sequence id changes.
    is_start = torch.ones(total, dtype=torch.bool, device=device)
    is_start[1:] = seq_ids[1:] != seq_ids[:-1]

    if attn_mask.dim() > 1:
        # Row index of every kept token, in the same row-major order that
        # boolean indexing used above. A row change is also a sequence start.
        row_ids = keep.reshape(-1, keep.shape[-1]).nonzero(as_tuple=True)[0]
        is_start[1:] |= row_ids[1:] != row_ids[:-1]

    # Start offsets followed by the total token count are exactly the
    # cumulative lengths (the first start offset is always 0).
    starts = is_start.nonzero(as_tuple=True)[0]
    end = torch.tensor([total], dtype=starts.dtype, device=device)
    return torch.cat([starts, end]).to(dtype=torch.int32)
def forward(
self,
hidden_states: torch.Tensor,
@@ -91,35 +118,15 @@ def forward(
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = get_cu_seqlens(key_padding_mask)
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,