misc
This commit is contained in:
@@ -144,7 +144,7 @@ def get_seqlens_from_pos_ids(position_ids):
|
|||||||
results.append(seq_lengths)
|
results.append(seq_lengths)
|
||||||
totalseqlens.append(len(adjusted_row))
|
totalseqlens.append(len(adjusted_row))
|
||||||
|
|
||||||
return results , totalseqlens
|
return results , torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
|
|||||||
@@ -243,33 +243,4 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
arrays = [np.array(item) for item in features[feature]]
|
arrays = [np.array(item) for item in features[feature]]
|
||||||
chunked_data[feature] = np.concatenate(arrays)
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
features = [chunked_data]
|
features = [chunked_data]
|
||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
def _get_document_ids_from_seq_lens(
|
|
||||||
seq_lens: List[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
|
||||||
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
|
|
||||||
Args:
|
|
||||||
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
|
||||||
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
|
||||||
across packs.
|
|
||||||
Returns:
|
|
||||||
Tensor: Document IDs of shape (batch_size, max_seq_len).
|
|
||||||
"""
|
|
||||||
batch_size = len(seq_lens)
|
|
||||||
batch_document_ids = []
|
|
||||||
for sample_idx in range(batch_size):
|
|
||||||
# We assume seq lens sum to max seq lens, so document_ids should be of
|
|
||||||
# shape (max_seq_len, )
|
|
||||||
document_ids = torch.cat(
|
|
||||||
[
|
|
||||||
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
|
|
||||||
for i, seq_len in enumerate(seq_lens[sample_idx])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
batch_document_ids.append(document_ids)
|
|
||||||
batch_document_ids = torch.stack(batch_document_ids)
|
|
||||||
return batch_document_ids
|
|
||||||
Reference in New Issue
Block a user