diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index f23fcf7fa..eeab5d564 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -127,6 +127,18 @@ def get_seqlens_from_pos_ids(position_ids): ) # Calculate the sequence lengths seq_lengths = start_indices[1:] - start_indices[:-1] + # Append the padding length to the sequence lengths + if padding_length: + seq_lengths = torch.cat( + [ + seq_lengths, + torch.tensor( + [len(row) - torch.sum(seq_lengths)], + dtype=torch.int32, + device=device, + ), + ] + ) results.append(seq_lengths)