diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 5c41d16fe..bca21076e 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -698,6 +698,24 @@ def get_dataset_wrapper(
     return dataset_wrapper, dataset_prompter
 
 
+def encode_packed_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+):
+    # Tokenize all the examples in one batched tokenizer call; long rows are
+    # split into overlapping chunks (stride=256 tokens of overlap) via
+    # return_overflowing_tokens so no text is silently truncated away.
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_tokens,
+        add_special_tokens=True,
+        return_overflowing_tokens=True,
+        stride=256,
+    )
+    # TODO: convert `res` to a Dataset via Dataset.from_list, then use a
+    # DataLoader with a multipack batch sampler to pack rows to max_tokens.
+    return res
+
+
 def encode_pretraining(
     tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
 ) -> Dict[str, List]:
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     dataset = dataset.map(
         encode,
         batched=True,
+        batch_size=10_000,
         input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column