sample packing doc mask creation WIP

This commit is contained in:
bursteratom
2025-01-21 09:18:38 -05:00
parent 80bfc50d1f
commit b2a34380b3

View File

@@ -3,12 +3,15 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
@dataclass
class DataCollatorForSeq2Seq:
@@ -166,12 +169,13 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features_)
seq_len_list = [
get_seqlens_from_pos_ids(item["position_ids"])
for item in features_
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
out_features[i][feature] = np.concatenate(seq_len_list)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
@@ -238,3 +242,32 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
def _get_document_ids_from_seq_lens(
seq_lens: List[torch.Tensor],
) -> torch.Tensor:
"""
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Document IDs of shape (batch_size, max_seq_len).
"""
batch_size = len(seq_lens)
batch_document_ids = []
for sample_idx in range(batch_size):
# We assume seq lens sum to max seq lens, so document_ids should be of
# shape (max_seq_len, )
document_ids = torch.cat(
[
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
for i, seq_len in enumerate(seq_lens[sample_idx])
]
)
batch_document_ids.append(document_ids)
batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids