Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
effb281b24 wip for multipack pretraining 2023-11-25 17:12:20 -05:00
Wing Lian
6a4562ac08 update datasets version to cut down the warnings due to pyarrow arg change (#897)
* update datasets to cut down the warnings

* set versions for tokenizers and gradio

* upgrade transformers to latest version
2023-11-25 16:30:00 -05:00
2 changed files with 23 additions and 3 deletions

View File

@@ -2,14 +2,15 @@
auto-gptq==0.5.1
packaging
peft==0.6.0
transformers==4.35.1
transformers==4.35.2
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.24.1
deepspeed
addict
fire
PyYAML>=6.0
datasets>=2.14.0
datasets>=2.15.0
flash-attn==2.3.3
sentencepiece
wandb
@@ -29,7 +30,7 @@ scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
gradio
gradio==3.50.2
tensorboard
# remote filesystems

View File

@@ -698,6 +698,24 @@ def get_dataset_wrapper(
return dataset_wrapper, dataset_prompter
def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
):
    """Tokenize raw text examples for packed pretraining.

    Each example is tokenized with truncation at ``max_tokens``; overflowing
    tokens come back as additional rows with a 256-token stride (overlap),
    so long inputs are split into multiple rows instead of having their
    tails silently truncated away.

    Args:
        tokenizer: tokenizer used to encode the raw text.
        max_tokens: maximum sequence length per tokenized row.
        examples: raw text strings to tokenize.

    Returns:
        The tokenizer's batch encoding, where one input example may map to
        several overflow rows. Downstream code can build a dataset from this
        (e.g. ``Dataset.from_list``) and pack rows with a multipack batch
        sampler.
    """
    # return_overflowing_tokens + stride: rows get split with overlap rather
    # than dropped when they exceed max_tokens
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )
    # Bug fix: the encoded batch was previously discarded (function ended in
    # `pass`, returning None). Return it so callers can consume the rows.
    return res
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column