more test

This commit is contained in:
bursteratom
2025-02-02 00:48:57 -05:00
parent 2319e5276d
commit b692d394b1

View File

@@ -127,6 +127,18 @@ def get_seqlens_from_pos_ids(position_ids):
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)