sample packing doc mask creation WIP
This commit is contained in:
@@ -3,12 +3,15 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -166,12 +169,13 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature == "length":
|
if feature == "length":
|
||||||
continue
|
continue
|
||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
arrays = [
|
seq_len_list = [
|
||||||
(i + 1) * np.array(item[feature])
|
get_seqlens_from_pos_ids(item["position_ids"])
|
||||||
for i, item in enumerate(features_)
|
for item in features_
|
||||||
if feature in item
|
if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
|
||||||
|
out_features[i][feature] = np.concatenate(seq_len_list)
|
||||||
else:
|
else:
|
||||||
arrays = [
|
arrays = [
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
@@ -238,3 +242,32 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
chunked_data[feature] = np.concatenate(arrays)
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
features = [chunked_data]
|
features = [chunked_data]
|
||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_document_ids_from_seq_lens(
|
||||||
|
seq_lens: List[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
||||||
|
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
|
||||||
|
Args:
|
||||||
|
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
||||||
|
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
||||||
|
across packs.
|
||||||
|
Returns:
|
||||||
|
Tensor: Document IDs of shape (batch_size, max_seq_len).
|
||||||
|
"""
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
batch_document_ids = []
|
||||||
|
for sample_idx in range(batch_size):
|
||||||
|
# We assume seq lens sum to max seq lens, so document_ids should be of
|
||||||
|
# shape (max_seq_len, )
|
||||||
|
document_ids = torch.cat(
|
||||||
|
[
|
||||||
|
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
|
||||||
|
for i, seq_len in enumerate(seq_lens[sample_idx])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
batch_document_ids.append(document_ids)
|
||||||
|
batch_document_ids = torch.stack(batch_document_ids)
|
||||||
|
return batch_document_ids
|
||||||
|
|||||||
Reference in New Issue
Block a user