Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
effb281b24 wip for multipack pretraining 2023-11-25 17:12:20 -05:00

View File

@@ -698,6 +698,24 @@ def get_dataset_wrapper(
return dataset_wrapper, dataset_prompter
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
):
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_tokens,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
# convert to a dataset.from_list
# use a dataloader and multipack batch sampler to pack the data
pass
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column