more test
This commit is contained in:
@@ -127,6 +127,18 @@ def get_seqlens_from_pos_ids(position_ids):
|
||||
)
|
||||
# Calculate the sequence lengths
|
||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||
# Append the padding length to the sequence lengths
|
||||
if padding_length:
|
||||
seq_lengths = torch.cat(
|
||||
[
|
||||
seq_lengths,
|
||||
torch.tensor(
|
||||
[len(row) - torch.sum(seq_lengths)],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
results.append(seq_lengths)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user