Optimize calculation of cu_seqlens from position_ids (#1084) [skip ci]
This commit is contained in:
@@ -55,6 +55,7 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
@@ -81,7 +82,7 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
|||||||
# Get the indices where the sequence starts
|
# Get the indices where the sequence starts
|
||||||
start_indices = torch.cat(
|
start_indices = torch.cat(
|
||||||
[
|
[
|
||||||
(seq_starts).nonzero(as_tuple=True)[0],
|
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user