Compare commits

...

2 Commits

Author: Wing Lian | SHA1: effb281b24 | Date: 2023-11-25 17:12:20 -05:00
wip for multipack pretraining

Author: Wing Lian | SHA1: 6a4562ac08 | Date: 2023-11-25 16:30:00 -05:00
update datasets version to cut down the warnings due to pyarrow arg change (#897)

* update datasets to cut down the warnings
* set versions for tokenizers and gradio
* upgrade transformers to latest version
2 changed files with 23 additions and 3 deletions

View File

@@ -2,14 +2,15 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers==4.35.1
+transformers==4.35.2
+tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1
 deepspeed
 addict
 fire
 PyYAML>=6.0
-datasets>=2.14.0
+datasets>=2.15.0
 flash-attn==2.3.3
 sentencepiece
 wandb
@@ -29,7 +30,7 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
-gradio
+gradio==3.50.2
 tensorboard
 # remote filesystems

View File

@@ -698,6 +698,24 @@ def get_dataset_wrapper(
     return dataset_wrapper, dataset_prompter


+def encode_packed_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+):
+    # tokenize all the examples
+    # rows get split with stride (overlap)
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_tokens,
+        add_special_tokens=True,
+        return_overflowing_tokens=True,
+        stride=256,
+    )
+    # convert to a dataset.from_list
+    # use a dataloader and multipack batch sampler to pack the data
+    pass
+
+
 def encode_pretraining(
     tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
 ) -> Dict[str, List]:
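The new encode_packed_pretraining is committed as a work-in-progress stub: it performs the chunked tokenization and then hits pass where the comments call for a Dataset.from_list conversion and multipack batching. As a rough illustration only, the sketch below shows one way those trailing comments could be realized; the row flattening, the labels handling, and the function name are assumptions, not the author's implementation, and the dataloader/multipack batch sampler step is omitted.

# Hypothetical completion of the stub above: a sketch, not the committed code.
from typing import List

from datasets import Dataset
from transformers import PreTrainedTokenizerBase


def encode_packed_pretraining_sketch(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dataset:
    # chunked tokenization: long texts are split into overlapping rows
    # (256-token stride); requires a fast tokenizer for overflowing tokens
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )
    # flatten the overflowing output into one record per chunk
    rows = [
        {
            "input_ids": ids,
            "attention_mask": mask,
            "labels": list(ids),  # causal LM labels mirror the inputs
        }
        for ids, mask in zip(res["input_ids"], res["attention_mask"])
    ]
    # build a Dataset; a DataLoader with a multipack batch sampler would then
    # pack these variable-length rows into dense batches (not shown here)
    return Dataset.from_list(rows)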
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     dataset = dataset.map(
         encode,
         batched=True,
+        batch_size=10_000,
         input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
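The only change in this second hunk is the explicit batch_size=10_000, which controls how many raw rows Dataset.map hands to the batched encode function per call; a larger batch mainly trades memory for fewer per-call round trips, which matters for streamed pretraining data. Below is a hedged, self-contained illustration of such a call with a placeholder tokenizer and toy data; it is not the surrounding load_pretraining_dataset code.

# Toy illustration of dataset.map with an explicit batch_size; the names here
# (gpt2 tokenizer, sample text) are placeholders, not taken from the diff.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer


def encode(texts, max_tokens=2048):
    # with batched=True and input_columns="text", this receives a list of strings
    return tokenizer(texts, truncation=True, max_length=max_tokens)


raw = Dataset.from_dict({"text": ["a long pretraining document ..."] * 3})
tokenized = raw.map(
    encode,
    batched=True,
    batch_size=10_000,  # tokenize up to 10k rows per map call
    input_columns="text",
    # drop the raw columns after mapping, since the encoded column can end up
    # with a different length than the original text column
    remove_columns=raw.column_names,
)
print(tokenized)  # Dataset with input_ids / attention_mask columns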