This commit is contained in:
bursteratom
2025-02-02 01:27:15 -05:00
parent e98581f6f5
commit 0ebab63309

View File

@@ -128,7 +128,7 @@ def get_seqlens_from_pos_ids(position_ids):
# Calculate the sequence lengths # Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1] seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths # Append the padding length to the sequence lengths
if padding_length: """if padding_length:
seq_lengths = torch.cat( seq_lengths = torch.cat(
[ [
seq_lengths, seq_lengths,
@@ -138,7 +138,7 @@ def get_seqlens_from_pos_ids(position_ids):
device=device, device=device,
), ),
] ]
) )"""
results.append(seq_lengths) results.append(seq_lengths)