From e5b36900e48311a5c1b42d36e38307fc06192e10 Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Sun, 2 Feb 2025 20:32:03 -0500
Subject: [PATCH] misc

---
 src/axolotl/monkeypatch/utils.py        |  2 +-
 src/axolotl/utils/collators/batching.py | 31 +------------------------
 2 files changed, 2 insertions(+), 31 deletions(-)

diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index dcc0f5645..4665a54d4 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -144,7 +144,7 @@ def get_seqlens_from_pos_ids(position_ids):
         results.append(seq_lengths)
         totalseqlens.append(len(adjusted_row))
 
-    return results , totalseqlens
+    return results , torch.tensor(totalseqlens, dtype=torch.int32, device=device)
 
 
 def get_cu_seqlens_from_pos_ids(position_ids):
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 5a4d081de..21dc26945 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -243,33 +243,4 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             arrays = [np.array(item) for item in features[feature]]
             chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
-        return super().__call__(features, return_tensors=return_tensors)
-
-
-def _get_document_ids_from_seq_lens(
-    seq_lens: List[torch.Tensor],
-) -> torch.Tensor:
-    """
-    Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
-    For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
-    Args:
-        seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
-            shape (batch_size, n), where n is the max number of sequences in a pack and can vary
-            across packs.
-    Returns:
-        Tensor: Document IDs of shape (batch_size, max_seq_len).
-    """
-    batch_size = len(seq_lens)
-    batch_document_ids = []
-    for sample_idx in range(batch_size):
-        # We assume seq lens sum to max seq lens, so document_ids should be of
-        # shape (max_seq_len, )
-        document_ids = torch.cat(
-            [
-                torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
-                for i, seq_len in enumerate(seq_lens[sample_idx])
-            ]
-        )
-        batch_document_ids.append(document_ids)
-    batch_document_ids = torch.stack(batch_document_ids)
-    return batch_document_ids
+        return super().__call__(features, return_tensors=return_tensors)
\ No newline at end of file