WIP: multipack pretraining
This commit is contained in:
@@ -698,6 +698,24 @@ def get_dataset_wrapper(
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
):
    """Tokenize raw pretraining examples in preparation for multipack batching.

    Every example is tokenized with truncation at ``max_tokens``. Rows that
    overflow the limit are split into additional rows that overlap their
    neighbour by 256 tokens (``stride``), so text at a truncation boundary
    is not lost.

    NOTE(review): work-in-progress stub — the tokenized output is computed
    but not yet returned or consumed; the planned follow-up (per the
    original comments) converts it to a dataset and packs it with a
    multipack batch sampler.
    """
    # Tokenize everything up front; overflowing tokens come back as extra
    # rows sharing a `stride`-token overlap with the row they spilled from.
    tokenized = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )
    # TODO: build a Dataset.from_list from `tokenized`, then pack it via a
    # DataLoader driven by the multipack batch sampler.
|
||||
|
||||
|
||||
def encode_pretraining(
|
||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
||||
) -> Dict[str, List]:
|
||||
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
batch_size=10_000,
|
||||
input_columns="text",
|
||||
# remove all the existing columns after mapping since they end up having
|
||||
# a different length than the encoded/tokenized column
|
||||
|
||||
Reference in New Issue
Block a user