use cumulative seq len with var len flash attn v2 w packing
This commit is contained in:
@@ -7,7 +7,6 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||||
@@ -19,6 +18,34 @@ except ImportError:
|
|||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_cu_seqlens(attn_mask):
|
||||||
|
device = attn_mask.device
|
||||||
|
# Exclude zeros to avoid adding their positions to the mask
|
||||||
|
t_non_zeros = attn_mask[attn_mask != 0]
|
||||||
|
# Find where the sequence number changes (including the first position)
|
||||||
|
seq_change = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([1], dtype=torch.int32, device=device),
|
||||||
|
t_non_zeros[1:] != t_non_zeros[:-1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Get the indices where the sequence changes
|
||||||
|
change_indices = torch.cat(
|
||||||
|
[
|
||||||
|
(seq_change == 1).nonzero(as_tuple=True)[0],
|
||||||
|
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Calculate the sequence lengths
|
||||||
|
seq_lengths = change_indices[1:] - change_indices[:-1]
|
||||||
|
# Calculate the cumulative sequence lengths
|
||||||
|
cu_seqlens = torch.cat(
|
||||||
|
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||||
|
)
|
||||||
|
|
||||||
|
return cu_seqlens.to(dtype=torch.int32)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -91,35 +118,15 @@ def forward(
|
|||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
else:
|
else:
|
||||||
nheads = qkv.shape[-2]
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
max_s = q_len
|
||||||
|
cu_q_lens = get_cu_seqlens(key_padding_mask)
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
|
||||||
x_unpad = rearrange(
|
|
||||||
x_unpad,
|
|
||||||
"nnz (three h d) -> nnz three h d",
|
|
||||||
three=3,
|
|
||||||
h=nheads,
|
|
||||||
)
|
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
|
||||||
x_unpad,
|
|
||||||
cu_q_lens,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
output = rearrange(
|
|
||||||
pad_input(
|
|
||||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
|
||||||
indices,
|
|
||||||
bsz,
|
|
||||||
q_len,
|
|
||||||
),
|
|
||||||
"b s (h d) -> b s h d",
|
|
||||||
h=nheads,
|
|
||||||
)
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||||
None,
|
None,
|
||||||
|
|||||||
Reference in New Issue
Block a user