make it work with pythia in the cloud

Wing Lian
2023-04-14 07:24:55 -04:00
parent ce24f5e246
commit 8d959a7e26
7 changed files with 352 additions and 70 deletions

View File

@@ -44,6 +44,7 @@ class JsonToJsonlConverter:
     def convert(self, input_file_path, output_file_path):
         content = self.file_reader.read(input_file_path)
         data = self.json_parser.parse(content)
+        # data = [r for r in data if r["conversations"]]  # vicuna cleaned has rows with empty conversations
        jsonl_content = self.jsonl_serializer.serialize(data)
         self.file_writer.write(jsonl_content)
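
Note: the commented-out filter is there because the vicuna-cleaned dump contains rows whose "conversations" list is empty, and those rows break prompt building downstream. A minimal standalone sketch of what that filter does (the rows here are made up for illustration):

    rows = [
        {"id": "a1", "conversations": [{"from": "human", "value": "hi"}]},
        {"id": "a2", "conversations": []},  # would fail prompt building later
    ]
    kept = [r for r in rows if r["conversations"]]
    assert [r["id"] for r in kept] == ["a1"]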

View File

@@ -2,7 +2,7 @@ from typing import List
 import torch
 from datasets import IterableDataset
-from .prompt_tokenizers import PromptTokenizingStrategy
+from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException


 # We want this to be a wrapper for an existing dataset that we have loaded
@@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset):
     def __iter__(self):
         iterator = iter(self.dataset)
-        yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
+        # Loop through the entire dataset
+        for example in iterator:
+            try:
+                yield self.prompt_tokenizer.tokenize_prompt(example)
+            except InvalidDataException:
+                pass


 class ConstantLengthDataset(IterableDataset):
@@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset):
         Args:
             tokenizer (Tokenizer): The processor used for proccessing the data.
             dataset (dataset.Dataset): Dataset with text files.
-            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
             seq_length (int): Length of token sequences to return.
-            chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
     """

     def __init__(
         self,
         tokenizer,
         datasets,
-        infinite=False,
         seq_length=2048,
-        num_of_sequences=1024,
-        chars_per_token=3.6,
     ):
         self.tokenizer = tokenizer
-        self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
+        self.concat_token_id = tokenizer.eos_token_id
         self.datasets: List[IterableDataset] = datasets
         self.seq_length = seq_length
-        self.infinite = infinite
-        self.current_size = 0
-        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

     def __iter__(self):
-        iterator = iter(self.datasets)
-        more_examples = True
-        while more_examples:
-            buffer, buffer_len = [], 0
-            while True:
-                if buffer_len >= self.max_buffer_size:
-                    break
-                try:
-                    buffer.append(next(iterator))
-                    buffer_len += len(buffer[-1])
-                except StopIteration:
-                    if self.infinite:
-                        iterator = iter(self.datasets)
-                    else:
-                        more_examples = False
-                        break
-            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
-            all_token_ids = []
-            for tokenized_input in tokenized_inputs:
-                all_token_ids.extend(tokenized_input + [self.concat_token_id])
-            for i in range(0, len(all_token_ids), self.seq_length):
-                input_ids = all_token_ids[i : i + self.seq_length]
-                if len(input_ids) == self.seq_length:
-                    self.current_size += 1
-                    yield {
-                        "input_ids": torch.LongTensor(input_ids),
-                        "labels": torch.LongTensor(input_ids),
-                        "attention_masks": torch.LongTensor(input_ids),
-                    }
+        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+        buffer_len = 0
+        for dataset in self.datasets:
+            iterator = iter(dataset)
+            more_examples = True
+            while more_examples:
+                try:
+                    example = next(iterator)
+                except StopIteration:
+                    more_examples = False
+                    example = None
+
+                add_concat_token = False
+                if example:
+                    example_len = len(example["input_ids"])
+                    add_concat_token = example["input_ids"][-1] != self.concat_token_id
+                else:
+                    example_len = 0
+
+                if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
+                    if buffer["input_ids"]:
+                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
+                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
+                        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
+                        yield {
+                            "input_ids": input_ids,
+                            "labels": labels,
+                            "attention_mask": attention_mask,
+                        }
+                    buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+                    buffer_len = 0
+
+                if example:
+                    input_ids = example["input_ids"]
+                    attention_mask = example["attention_mask"]
+                    labels = example["labels"]
+                    if add_concat_token:
+                        input_ids.append(self.concat_token_id)
+                        attention_mask.append(1)
+                        labels.append(self.concat_token_id)
+                    input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
+                    attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
+                    labels_with_concat = torch.tensor(labels, dtype=torch.long)
+                    buffer["input_ids"].append(input_ids_with_concat)
+                    buffer["attention_mask"].append(attention_mask_with_concat)
+                    buffer["labels"].append(labels_with_concat)
+                    buffer_len += len(input_ids)
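
Note: the rewritten __iter__ no longer batches raw text by an estimated character budget; it packs already-tokenized examples into sequences of at most seq_length, inserting the EOS token between examples and flushing the buffer when the next example would not fit. A standalone sketch of the packing rule (toy values eos=0 and seq_length=8, tracking only input_ids; all data made up for illustration):

    import torch

    eos, seq_length = 0, 8
    examples = [
        {"input_ids": [5, 6, 7]},
        {"input_ids": [8, 9]},
        {"input_ids": [3, 4, 5, 6]},
    ]

    buffer, buffer_len, packed = [], 0, []
    for ex in examples + [None]:  # a trailing None flushes, like StopIteration above
        add_eos = bool(ex) and ex["input_ids"][-1] != eos
        ex_len = len(ex["input_ids"]) if ex else 0
        if not ex_len or buffer_len + int(add_eos) + ex_len > seq_length:
            if buffer:
                packed.append(torch.cat(buffer, dim=-1)[:seq_length])
            buffer, buffer_len = [], 0
        if ex:
            ids = ex["input_ids"] + ([eos] if add_eos else [])
            buffer.append(torch.tensor(ids, dtype=torch.long))
            buffer_len += len(ids)

    print([t.tolist() for t in packed])
    # [[5, 6, 7, 0, 8, 9, 0], [3, 4, 5, 6, 0]]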

View File

@@ -9,6 +9,10 @@ LLAMA_DEFAULT_BOS_TOKEN = "<s>"
 LLAMA_DEFAULT_UNK_TOKEN = "<unk>"


+class InvalidDataException(Exception):
+    pass
+
+
 class PromptTokenizingStrategy(abc.ABC):
     def __init__(
         self,
@@ -32,7 +36,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         full_prompt = self._tokenize_full_prompt(prompt)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = self.prompter.generate_prompt(
+            user_prompt = self.prompter.build_prompt(
                 prompt["instruction"], prompt["input"]
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
@@ -43,7 +47,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt

     def _tokenize_full_prompt(self, prompt):
-        return self.prompter.generate_prompt(
+        return self.prompter.build_prompt(
             prompt["instruction"],
             prompt["input"],
             prompt["output"],
@@ -71,7 +75,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
 class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
     def _tokenize_full_prompt(self, prompt):
-        return self.prompter.generate_prompt(
+        return self.prompter.build_prompt(
             prompt["instruction"],
             prompt["input"],
             prompt["response"],
@@ -80,4 +84,7 @@ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
-        pass
+        try:
+            return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
+        except (KeyError, AssertionError) as e:
+            raise InvalidDataException(str(e))
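
Note: InvalidDataException is the contract that lets TokenizedPromptDataset (earlier in this commit) drop malformed ShareGPT rows instead of crashing mid-iteration. A self-contained sketch of that skip pattern, with a stand-in tokenize_prompt and made-up rows:

    class InvalidDataException(Exception):
        pass

    def tokenize_prompt(row):  # stand-in for the real strategy method
        try:
            return row["conversations"][0]["value"].upper()
        except (KeyError, AssertionError) as e:
            raise InvalidDataException(str(e))

    rows = [{"conversations": [{"from": "human", "value": "hi"}]}, {"bad": True}]
    good = []
    for row in rows:
        try:
            good.append(tokenize_prompt(row))
        except InvalidDataException:
            pass  # malformed row is skipped, as in TokenizedPromptDataset.__iter__
    print(good)  # ['HI']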

View File

@@ -1,10 +1,160 @@
+import copy
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any, Union
+
+IGNORE_TOKEN_ID = -100
+
+
 class AlpacaPrompter:
-    pass
+    prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
+    response_split = "### Response:"
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        output: Union[None, str] = None,
+    ) -> str:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = self.prompt_input.format(instruction=instruction, input=input)
+        else:
+            res = self.prompt_no_input.format(instruction=instruction)
+        if output:
+            res = f"{res}{output}"
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(self.response_split)[1].strip()
+
+
+class GPTeacherPrompter(AlpacaPrompter):
+    ...
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    SINGLE = auto()
+    TWO = auto()
+    DOLLY = auto()
+
+
+# TODO clean this 💩 up
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+
+    def get_prompt(self):
+        seps = [self.sep, self.sep2]
+        ret = self.system + seps[0]
+        for i, (role, message) in enumerate(self.messages):
+            if message:
+                ret += role + ": " + message + seps[i % 2]
+            else:
+                ret += role + ":"
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+        )
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+
+conv_vicuna_v1_1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=["USER", "ASSISTANT"],
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+
 class ShareGPTPrompter:
-    pass
-
-
-class GPTeacherPrompter:
-    pass
+    def build_prompt(
+        self,
+        source,
+        tokenizer,
+    ):
+        if len(source) < 2:
+            # If there isn't a back and forth conversation, ignore it
+            # also happens on the data splitting leaving empty conversations
+            raise IndexError
+
+        conv = conv_vicuna_v1_1.copy()
+        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+        try:
+            # Apply prompt templates
+            if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
+                # Skip the first one if it is not from human
+                source = source[1:]
+        except IndexError as e:
+            # sometimes there is a bing or system chat
+            raise e
+
+        conv.messages = []
+        for j, sentence in enumerate(source):
+            role = roles[sentence["from"]]
+            assert role == conv.roles[j % 2]
+            conv.append_message(role, sentence["value"])
+        conversation = conv.get_prompt()
+
+        # Tokenize conversations
+        tokenized_result = tokenizer(
+            conversation,
+            truncation=True,
+            max_length=2048,  # FIXME
+            padding=False,
+            return_tensors=None,
+        )
+        target = copy.deepcopy(tokenized_result["input_ids"])
+
+        # Mask targets
+        sep = conv.sep + conv.roles[1] + ": "
+        rounds = conversation.split(conv.sep2)
+        cur_len = 1
+        for i, rou in enumerate(rounds):
+            if rou == "":
+                break
+
+            parts = rou.split(sep)
+            if len(parts) != 2:
+                break
+            parts[0] += sep
+            round_len = len(tokenizer(rou)["input_ids"])
+            instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
+            target[cur_len : cur_len + instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
+            cur_len += round_len
+        target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+
+        attention_mask = [
+            1 if x != tokenizer.pad_token_id else 0
+            for x in tokenized_result["input_ids"]
+        ]
+
+        return dict(
+            input_ids=tokenized_result["input_ids"],
+            labels=target,
+            attention_mask=attention_mask,
+        )
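
Note: to make the two prompt formats concrete, here is an illustration using the classes added above (the sample instruction and messages are made up). The ShareGPT path additionally masks everything up to each "ASSISTANT: " separator with IGNORE_TOKEN_ID, so the loss is computed only on the assistant's replies.

    alpaca = AlpacaPrompter()
    print(alpaca.build_prompt("Name a color.", output="Blue"))
    # Below is an instruction that describes a task. Write a response that appropriately completes the request.
    #
    # ### Instruction:
    # Name a color.
    #
    # ### Response:
    # Blue

    conv = conv_vicuna_v1_1.copy()
    conv.append_message(conv.roles[0], "Name a color.")
    conv.append_message(conv.roles[1], "Blue")
    print(conv.get_prompt())
    # (one line, wrapped here)
    # A chat between a curious user and an artificial intelligence assistant. The assistant
    # gives helpful, detailed, and polite answers to the user's questions. USER: Name a color. ASSISTANT: Blue</s>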