WIP for axolotl trainer

This commit is contained in:
Wing Lian
2023-04-14 00:20:05 -04:00
parent e9da4b9a30
commit ce24f5e246
16 changed files with 497 additions and 1 deletions

0
src/axolotl/__init__.py Normal file
View File

50
src/axolotl/convert.py Normal file
View File

@@ -0,0 +1,50 @@
import json
import sys
class FileReader:
def read(self, file_path):
with open(file_path, "r") as file:
return file.read()
class FileWriter:
def __init__(self, file_path):
self.file_path = file_path
def write(self, content):
with open(self.file_path, "w") as file:
file.write(content)
class StdoutWriter:
def write(self, content):
sys.stdout.write(content)
sys.stdout.write("\n")
class JsonParser:
def parse(self, content):
return json.loads(content)
class JsonlSerializer:
def serialize(self, data):
lines = [json.dumps(item) for item in data]
return "\n".join(lines)
class JsonToJsonlConverter:
def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
self.file_reader = file_reader
self.file_writer = file_writer
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
jsonl_content = self.jsonl_serializer.serialize(data)
self.file_writer.write(jsonl_content)

86
src/axolotl/datasets.py Normal file
View File

@@ -0,0 +1,86 @@
from typing import List
import torch
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
class TokenizedPromptDataset(IterableDataset):
def __init__(
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
):
self.prompt_tokenizer = prompt_tokenizer
self.dataset = dataset
def __iter__(self):
iterator = iter(self.dataset)
yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
"""
def __init__(
self,
tokenizer,
datasets,
infinite=False,
seq_length=2048,
num_of_sequences=1024,
chars_per_token=3.6,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
self.datasets: List[IterableDataset] = datasets
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
def __iter__(self):
iterator = iter(self.datasets)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(next(iterator))
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.datasets)
else:
more_examples = False
break
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
self.current_size += 1
yield {
"input_ids": torch.LongTensor(input_ids),
"labels": torch.LongTensor(input_ids),
"attention_masks": torch.LongTensor(input_ids),
}

View File

@@ -0,0 +1,83 @@
import abc
from transformers import PreTrainedTokenizer
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
class PromptTokenizingStrategy(abc.ABC):
def __init__(
self,
prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
):
self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
self.sequence_len = sequence_len
@abc.abstractmethod
def tokenize_prompt(self, prompt):
pass
class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
full_prompt = self._tokenize_full_prompt(prompt)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = self.prompter.generate_prompt(
prompt["instruction"], prompt["input"]
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt
def _tokenize_full_prompt(self, prompt):
return self.prompter.generate_prompt(
prompt["instruction"],
prompt["input"],
prompt["output"],
)
def _tokenize(self, prompt, add_eos_token=True):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
def _tokenize_full_prompt(self, prompt):
return self.prompter.generate_prompt(
prompt["instruction"],
prompt["input"],
prompt["response"],
)
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
pass

10
src/axolotl/prompters.py Normal file
View File

@@ -0,0 +1,10 @@
class AlpacaPrompter:
pass
class ShareGPTPrompter:
pass
class GPTeacherPrompter:
pass