This commit is contained in:
Sunny Liu
2025-02-02 20:32:03 -05:00
parent 9f6c89b12b
commit e5b36900e4
2 changed files with 2 additions and 31 deletions

View File

@@ -144,7 +144,7 @@ def get_seqlens_from_pos_ids(position_ids):
results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results , totalseqlens
return results , torch.tensor(totalseqlens, dtype=torch.int32, device=device)
def get_cu_seqlens_from_pos_ids(position_ids):

View File

@@ -243,33 +243,4 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
def _get_document_ids_from_seq_lens(
seq_lens: List[torch.Tensor],
) -> torch.Tensor:
"""
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Document IDs of shape (batch_size, max_seq_len).
"""
batch_size = len(seq_lens)
batch_document_ids = []
for sample_idx in range(batch_size):
# We assume seq lens sum to max seq lens, so document_ids should be of
# shape (max_seq_len, )
document_ids = torch.cat(
[
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
for i, seq_len in enumerate(seq_lens[sample_idx])
]
)
batch_document_ids.append(document_ids)
batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids
return super().__call__(features, return_tensors=return_tensors)