diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index c2772b471..824a56e33 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -95,6 +95,55 @@ def get_cu_seqlens(attn_mask):
     return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
 
 
+def get_seqlens_from_pos_ids(position_ids):
+    """generate a sequence length set using pos ids for doc mask creation in flex attention"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                torch.nonzero(seq_starts).unbind(dim=1)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Append the padding length to the sequence lengths
+        if padding_length:
+            seq_lengths = torch.cat(
+                [
+                    seq_lengths,
+                    torch.tensor(
+                        [len(row) - torch.sum(seq_lengths)],
+                        dtype=torch.int32,
+                        device=device,
+                    ),
+                ]
+            )
+
+        results.append(seq_lengths)
+
+    return results
+
+
 def get_cu_seqlens_from_pos_ids(position_ids):
     """generate a cumulative sequence length mask for flash attention using pos ids"""
     if len(position_ids.shape) == 1:
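A minimal sanity check of what the new helper returns, assuming the behavior as written above (this snippet is not part of the patch; the example tensor and expected output are inferred from reading the code):

import torch
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids

# Two packed sequences (position ids 0..3 and 0..2), then 3 right-padding zeros.
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 0, 0]])

# The trailing padding run is appended as a final "sequence" of its own length,
# so each row yields its per-document lengths plus one padding entry.
print(get_seqlens_from_pos_ids(position_ids))
# expected: [tensor([4, 3, 3])]

Note that the function returns a Python list with one lengths tensor per batch row rather than a stacked tensor, since rows can contain different numbers of packed documents.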