From b2a34380b335db72a38350a9f9274520d61c76a2 Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Tue, 21 Jan 2025 09:18:38 -0500
Subject: [PATCH] sample packing doc mask creation WIP

---
 src/axolotl/utils/collators/batching.py | 43 ++++++++++++++++++++++---
 1 file changed, 38 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 13a6e1967..579fbfbb3 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -3,12 +3,15 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import numpy as np
+import torch
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
+from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
+
 
 @dataclass
 class DataCollatorForSeq2Seq:
     """
@@ -166,12 +169,13 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 if feature == "length":
                     continue
                 if feature == "attention_mask":
-                    arrays = [
-                        (i + 1) * np.array(item[feature])
-                        for i, item in enumerate(features_)
+                    seq_len_list = [
+                        get_seqlens_from_pos_ids(item["position_ids"])
+                        for item in features_
                         if feature in item
                     ]
-                    out_features[i][feature] = np.concatenate(arrays)
+
+                    out_features[i][feature] = np.concatenate(seq_len_list)
                 else:
                     arrays = [
                         np.array(item[feature]) for item in features_ if feature in item
@@ -238,3 +242,32 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
+
+
+def _get_document_ids_from_seq_lens(
+    seq_lens: List[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Convert a batch of seq-lens tensors into integer IDs denoting sample ownership.
+    For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
+    Args:
+        seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
+            shape (batch_size, n), where n is the max number of sequences in a pack and can vary
+            across packs.
+    Returns:
+        Tensor: Document IDs of shape (batch_size, max_seq_len).
+    """
+    batch_size = len(seq_lens)
+    batch_document_ids = []
+    for sample_idx in range(batch_size):
+        # We assume seq lens sum to max seq lens, so document_ids should be of
+        # shape (max_seq_len, )
+        document_ids = torch.cat(
+            [
+                torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
+                for i, seq_len in enumerate(seq_lens[sample_idx])
+            ]
+        )
+        batch_document_ids.append(document_ids)
+    batch_document_ids = torch.stack(batch_document_ids)
+    return batch_document_ids
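
Note on the collator change above: get_seqlens_from_pos_ids (imported from axolotl.monkeypatch.utils) recovers per-sample sequence lengths from packed position_ids, replacing the old trick of encoding sample ownership as (i + 1) * attention_mask. In a packed sequence, position_ids restart at 0 at each sample boundary, so lengths can be read off the reset points. A minimal numpy sketch of that idea, assuming the helper returns one length per sample (the helper's exact return shape is not shown in this patch, and seqlens_from_pos_ids_sketch is an illustrative name, not the real function):

    import numpy as np

    def seqlens_from_pos_ids_sketch(position_ids) -> np.ndarray:
        # Packed position_ids restart at 0 at each sample boundary,
        # e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] -> lengths [3, 2, 4].
        position_ids = np.asarray(position_ids)
        starts = np.flatnonzero(position_ids == 0)  # sample start offsets
        # Lengths are the gaps between consecutive starts (plus the tail).
        return np.diff(np.append(starts, len(position_ids)))

    print(seqlens_from_pos_ids_sketch([0, 1, 2, 0, 1, 0, 1, 2, 3]))  # [3 2 4]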
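
The commit subject names document-mask creation as the goal. Given the document IDs produced by the new _get_document_ids_from_seq_lens helper, one natural next step is a block-diagonal "same document" mask intersected with a causal mask; the mask construction below is an illustrative sketch under that assumption, not part of this patch (it assumes the helper above is in scope):

    import torch

    # Two packs, each filling max_seq_len = 6.
    seq_lens = [
        torch.tensor([2, 3, 1]),  # pack 0: samples of length 2, 3, 1
        torch.tensor([4, 2]),     # pack 1: samples of length 4, 2
    ]

    doc_ids = _get_document_ids_from_seq_lens(seq_lens)
    # tensor([[0, 0, 1, 1, 1, 2],
    #         [0, 0, 0, 0, 1, 1]])

    # Token q may attend to token k only if both belong to the same sample.
    same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)  # (bsz, L, L)

    # Intersect with a lower-triangular causal mask for autoregressive training.
    max_seq_len = doc_ids.shape[-1]
    causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
    doc_causal_mask = same_doc & causal  # (bsz, L, L) boolean attention mask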