Compare commits

..

1 Commit

Author     SHA1        Message                              Date
Wing Lian  da154e6d56  support for json data as completion  2023-11-25 16:05:04 -05:00
3 changed files with 30 additions and 24 deletions

requirements.txt

@@ -2,15 +2,14 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers==4.35.2
-tokenizers==0.15.0
+transformers==4.35.1
 bitsandbytes>=0.41.1
 accelerate==0.24.1
 deepspeed
 addict
 fire
 PyYAML>=6.0
-datasets>=2.15.0
+datasets>=2.14.0
 flash-attn==2.3.3
 sentencepiece
 wandb
@@ -30,7 +29,7 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
-gradio==3.50.2
+gradio
 tensorboard
 # remote filesystems

src/axolotl/prompt_strategies/completion.py

@@ -1,6 +1,7 @@
 """
 Basic completion text
 """
+import json
 from collections import defaultdict
 from typing import Any, Dict, Generator, Optional, Tuple
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return next(iter(self.prompter.build_prompt(instruction, input, response)))
 
 
+class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
+    """
+    Strategy to return the stringified JSON of the entire row as the training data
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        return (
+            json.dumps(prompt),
+            "",
+            "",
+        )
+
+
 class CompletionPrompter:
     """
     Prompter for completion
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     strat = CompletionPromptTokenizingStrategy(
         CompletionPrompter(),
         tokenizer,
-        cfg.train_on_inputs,
+        True,
         cfg.sequence_len,
         max_length=cfg.sequence_len * 64,
     )
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         strat.field = ds_cfg["field"]
 
     return strat
+
+
+def load_json(tokenizer, cfg):
+    strat = CompletionJSONPromptTokenizationStrategy(
+        CompletionPrompter(),
+        tokenizer,
+        True,
+        cfg.sequence_len,
+        max_length=cfg.sequence_len * 64,
+    )
+    return strat
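
For reference, a minimal sketch (not part of the diff) of what the new CompletionJSONPromptTokenizationStrategy does with a dataset row, using a made-up toy row: parse_instruction_fields() stringifies the whole row with json.dumps and returns it as the instruction, leaving input and response empty, so the raw JSON is trained on as plain completion text.

import json

# Hypothetical row; any JSON-serializable mapping from the dataset works.
row = {"question": "What is 2 + 2?", "answer": "4"}

# Mirrors parse_instruction_fields(): the stringified row fills the
# instruction slot of the (instruction, input, response) triple, and the
# other two slots stay empty.
instruction, model_input, response = (json.dumps(row), "", "")
print(instruction)  # {"question": "What is 2 + 2?", "answer": "4"}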

src/axolotl/utils/data.py

@@ -698,24 +698,6 @@ def get_dataset_wrapper(
     return dataset_wrapper, dataset_prompter
 
 
-def encode_packed_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
-):
-    # tokenize all the examples
-    # rows get split with stride (overlap)
-    res = tokenizer(
-        examples,
-        truncation=True,
-        max_length=max_tokens,
-        add_special_tokens=True,
-        return_overflowing_tokens=True,
-        stride=256,
-    )
-    # convert to a dataset.from_list
-    # use a dataloader and multipack batch sampler to pack the data
-    pass
-
-
 def encode_pretraining(
     tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
 ) -> Dict[str, List]:
@@ -831,7 +813,6 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     dataset = dataset.map(
         encode,
         batched=True,
-        batch_size=10_000,
         input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
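
One note on the dropped batch_size: with the argument omitted, datasets' Dataset.map() falls back to its default of 1,000 rows per batch rather than the 10,000 pinned here. A minimal sketch with a toy dataset and a stand-in encode function (names are illustrative, not from this repo):

from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello", "world"]})

def encode(texts):
    # With batched=True and input_columns="text", the function receives
    # the column values for the whole batch as a list of strings.
    return {"text_len": [len(t) for t in texts]}

# batch_size omitted -> map() uses its default (1,000 rows per batch)
# instead of the previously pinned 10_000.
ds = ds.map(encode, batched=True, input_columns="text")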