From effb281b2453a7cd3583118cf411089e004d7d56 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 25 Nov 2023 17:12:20 -0500
Subject: [PATCH] wip for multipack pretraining

---
 src/axolotl/utils/data.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 5c41d16fe..bca21076e 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -698,6 +698,24 @@ def get_dataset_wrapper(
     return dataset_wrapper, dataset_prompter
 
 
+def encode_packed_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+):
+    # tokenize all the examples
+    # rows get split with stride (overlap)
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_tokens,
+        add_special_tokens=True,
+        return_overflowing_tokens=True,
+        stride=256,
+    )
+    # convert to a dataset.from_list
+    # use a dataloader and multipack batch sampler to pack the data
+    pass
+
+
 def encode_pretraining(
     tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
 ) -> Dict[str, List]:
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     dataset = dataset.map(
         encode,
         batched=True,
+        batch_size=10_000,
         input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
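
Note: encode_packed_pretraining is still a stub at this point in the patch. It
tokenizes with return_overflowing_tokens and stride but never returns `res`,
and the dataset conversion and packing steps exist only as comments. The
sketch below is one way those comments might be realized. It is a minimal,
hedged illustration, not the final implementation: the greedy first-fit loop
stands in for the dataloader plus multipack batch sampler the comments
describe (that sampler's wiring is not part of this patch), and the helper
name encode_packed_pretraining_sketch is hypothetical.

# A minimal sketch (not part of the patch) of how the stubbed steps might be
# completed. Assumption: greedy first-fit packing stands in for the dataloader
# + multipack batch sampler mentioned in the WIP comments.
from typing import Dict, List

from transformers import PreTrainedTokenizerBase


def encode_packed_pretraining_sketch(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List[List[int]]]:
    # tokenize all the examples; rows longer than max_tokens get split into
    # overlapping chunks via return_overflowing_tokens + stride
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )

    # greedy first-fit packing: append chunks to the current row until the
    # next chunk would overflow max_tokens, then start a new row; each chunk
    # is at most max_tokens long because of the truncation above
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    cur_ids: List[int] = []
    cur_mask: List[int] = []
    for ids in res["input_ids"]:
        if cur_ids and len(cur_ids) + len(ids) > max_tokens:
            input_ids.append(cur_ids)
            attention_mask.append(cur_mask)
            cur_ids, cur_mask = [], []
        # number the sequences within a packed row (1, 2, ...) so a
        # packing-aware collator can keep their attention separate
        seq_num = cur_mask[-1] + 1 if cur_mask else 1
        cur_ids.extend(ids)
        cur_mask.extend([seq_num] * len(ids))
    if cur_ids:
        input_ids.append(cur_ids)
        attention_mask.append(cur_mask)

    # dict-of-lists output plugs straight into datasets.Dataset.from_dict()
    # or a batched dataset.map() call
    return {"input_ids": input_ids, "attention_mask": attention_mask}

The dict-of-lists return shape is what the dataset.map(encode, batched=True,
batch_size=10_000, ...) call in the second hunk expects from its encode
function, so a completed version along these lines could slot in as `encode`
once the real multipack batch sampler is wired up.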