get seqlens from position ids for foc masking
This commit is contained in:
@@ -95,6 +95,55 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def get_seqlens_from_pos_ids(position_ids):
|
||||||
|
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
||||||
|
if len(position_ids.shape) == 1:
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
device = position_ids.device
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for row in position_ids:
|
||||||
|
# Count the number of consecutive zeros from the right side
|
||||||
|
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
||||||
|
|
||||||
|
# Adjust the row to exclude padding
|
||||||
|
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
||||||
|
|
||||||
|
# Find where the position resets to 0 (indicating a new sequence)
|
||||||
|
seq_starts = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([True], dtype=torch.bool, device=device),
|
||||||
|
adjusted_row[1:] == 0,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Get the indices where the sequence starts
|
||||||
|
start_indices = torch.cat(
|
||||||
|
[
|
||||||
|
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
||||||
|
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Calculate the sequence lengths
|
||||||
|
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||||
|
# Append the padding length to the sequence lengths
|
||||||
|
if padding_length:
|
||||||
|
seq_lengths = torch.cat(
|
||||||
|
[
|
||||||
|
seq_lengths,
|
||||||
|
torch.tensor(
|
||||||
|
[len(row) - torch.sum(seq_lengths)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(seq_lengths)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user