WIP for axolotl trainer
This commit is contained in:
0
src/axolotl/__init__.py
Normal file
0
src/axolotl/__init__.py
Normal file
50
src/axolotl/convert.py
Normal file
50
src/axolotl/convert.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
class FileReader:
    """Reads the entire contents of a text file into a single string."""

    def read(self, file_path):
        """Return the full text of the file at *file_path*."""
        with open(file_path, "r") as handle:
            return handle.read()
|
||||
|
||||
|
||||
class FileWriter:
    """Writes string content to a destination file fixed at construction."""

    def __init__(self, file_path):
        # The output path is chosen once, up front; write() reuses it.
        self.file_path = file_path

    def write(self, content):
        """Overwrite the destination file with *content*."""
        with open(self.file_path, "w") as handle:
            handle.write(content)
|
||||
|
||||
|
||||
class StdoutWriter:
    """Writes content to standard output, terminated with a newline."""

    def write(self, content):
        """Emit *content* plus a trailing newline on stdout."""
        sys.stdout.write(content + "\n")
|
||||
|
||||
|
||||
class JsonParser:
    """Deserializes a JSON document from a string."""

    def parse(self, content):
        """Return the Python object encoded by the JSON string *content*."""
        return json.loads(content)
|
||||
|
||||
|
||||
class JsonlSerializer:
    """Serializes an iterable of objects into JSON Lines text."""

    def serialize(self, data):
        """Return *data* as newline-delimited JSON, one item per line."""
        return "\n".join(json.dumps(item) for item in data)
|
||||
|
||||
|
||||
class JsonToJsonlConverter:
    """Pipeline that converts a JSON array file into JSON Lines output.

    Collaborators are injected so reading, parsing, serializing and writing
    can each be swapped independently (e.g. stdout writer vs. file writer).
    """

    def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
        self.file_reader = file_reader
        self.file_writer = file_writer
        self.json_parser = json_parser
        self.jsonl_serializer = jsonl_serializer

    def convert(self, input_file_path, output_file_path):
        """Read *input_file_path*, convert, and emit via the injected writer.

        Note: *output_file_path* is accepted but unused — the writer's
        destination was fixed when the writer was constructed.
        """
        raw = self.file_reader.read(input_file_path)
        parsed = self.json_parser.parse(raw)
        self.file_writer.write(self.jsonl_serializer.serialize(parsed))
|
||||
|
||||
|
||||
86
src/axolotl/datasets.py
Normal file
86
src/axolotl/datasets.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
# lets use the concept of middlewares to wrap each dataset, for example
|
||||
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
|
||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
||||
# the collators later on to pad the datasets
|
||||
|
||||
|
||||
class TokenizedPromptDataset(IterableDataset):
    """Wraps an iterable dataset, tokenizing each raw example on the fly.

    Every item produced by the wrapped dataset is passed through
    ``prompt_tokenizer.tokenize_prompt`` before being yielded.
    """

    def __init__(
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.dataset = dataset

    def __iter__(self):
        # Bug fix: the original called ``next(iterator)`` exactly once, so
        # iteration produced only the FIRST example of the wrapped dataset.
        # Yield a tokenized version of every example instead.
        for example in self.dataset:
            yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||
|
||||
|
||||
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant-length chunks of tokens from a stream of text.

    Args:
        tokenizer (Tokenizer): The tokenizer used for processing the data.
        datasets (List[IterableDataset]): Source datasets to draw text from.
        infinite (bool): If True the iterator is reset once the data is exhausted, else it stops.
        seq_length (int): Length of the token sequences to return.
        num_of_sequences (int): Number of sequences worth of text to buffer per tokenizer call.
        chars_per_token (float): Characters per token, used to estimate the token count of the buffer.
    """

    def __init__(
        self,
        tokenizer,
        datasets,
        infinite=False,
        seq_length=2048,
        num_of_sequences=1024,
        chars_per_token=3.6,
    ):
        self.tokenizer = tokenizer
        # Bug fix: the original fell back to ``args.eos_token_id``, but no
        # ``args`` exists in this scope, so a tokenizer without an EOS token
        # (or with EOS id 0, which is falsy) raised NameError.  Fail fast
        # with a clear message instead.
        if tokenizer.eos_token_id is None:
            raise ValueError(
                "ConstantLengthDataset requires a tokenizer with eos_token_id set"
            )
        # Token inserted between concatenated examples.
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length
        self.infinite = infinite
        # Number of fixed-length sequences yielded so far.
        self.current_size = 0
        # Estimated character budget for one buffer-full of text.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        # NOTE(review): this iterates over the *list* of datasets, so each
        # buffered item is a dataset object rather than a text example —
        # presumably the examples of every dataset should be chained here;
        # confirm against callers before changing.
        iterator = iter(self.datasets)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        # Restart from the beginning for endless epochs.
                        iterator = iter(self.datasets)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                # Separate concatenated examples with the EOS token.
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                # Drop the trailing partial chunk; padding is handled by
                # collators downstream.
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                        # Bug fix: the mask previously duplicated the token
                        # ids; every position here is a real token, so the
                        # mask is all ones.  (Key name kept as-is.)
                        "attention_masks": torch.LongTensor([1] * len(input_ids)),
                    }
|
||||
83
src/axolotl/prompt_tokenizers.py
Normal file
83
src/axolotl/prompt_tokenizers.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import abc
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
    """Base class for strategies that turn a raw prompt record into model inputs.

    Subclasses implement :meth:`tokenize_prompt`.
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs: bool = False,
        sequence_len: int = 2048,
    ):
        # Template builder for assembling the prompt text.
        self.prompter = prompter
        self.tokenizer: PreTrainedTokenizer = tokenizer
        # When False, subclasses mask the user-supplied portion of the
        # prompt out of the training labels.
        self.train_on_inputs = train_on_inputs
        # Maximum tokenized length.
        self.sequence_len = sequence_len

    @abc.abstractmethod
    def tokenize_prompt(self, prompt):
        """Convert one raw *prompt* record into tokenized features."""
|
||||
|
||||
|
||||
class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Tokenizes instruction/input/output records in the Alpaca format."""

    def tokenize_prompt(self, prompt):
        """Tokenize one record; optionally mask the user prompt in the labels."""
        full_prompt = self._tokenize_full_prompt(prompt)
        tokenized_full_prompt = self._tokenize(full_prompt)
        if not self.train_on_inputs:
            user_prompt = self.prompter.generate_prompt(
                prompt["instruction"], prompt["input"]
            )
            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # Mask the user-supplied prefix so the loss is computed only on
            # the response tokens.  (Was a hard-coded -100; use the
            # module-level IGNORE_INDEX constant for consistency.)
            # TODO this could be sped up using numpy array slicing
            tokenized_full_prompt["labels"] = (
                [IGNORE_INDEX] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )

        return tokenized_full_prompt

    def _tokenize_full_prompt(self, prompt):
        # Alpaca stores the model answer under the "output" key.
        return self.prompter.generate_prompt(
            prompt["instruction"],
            prompt["input"],
            prompt["output"],
        )

    def _tokenize(self, prompt, add_eos_token=True):
        """Tokenize *prompt*, appending EOS when requested and room remains."""
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        # Check emptiness before indexing [-1]: the original raised
        # IndexError when the tokenizer returned no tokens.
        if (
            add_eos_token
            and len(input_ids) > 0
            and len(input_ids) < self.sequence_len
            and input_ids[-1] != self.tokenizer.eos_token_id
        ):
            input_ids.append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Labels start as a copy of the inputs; tokenize_prompt may mask
        # the user-prompt prefix afterwards.
        result["labels"] = input_ids.copy()
        return result
|
||||
|
||||
|
||||
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
    """Alpaca-style tokenization for GPTeacher records.

    Identical to the Alpaca strategy except that the model answer lives
    under the "response" key instead of "output".
    """

    def _tokenize_full_prompt(self, prompt):
        fields = (prompt["instruction"], prompt["input"], prompt["response"])
        return self.prompter.generate_prompt(*fields)
|
||||
|
||||
|
||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Placeholder strategy for ShareGPT conversation data."""

    def tokenize_prompt(self, prompt):
        # Not implemented yet: every record silently produces None.
        # NOTE(review): consider raising NotImplementedError so callers
        # fail loudly instead of receiving None.
        pass
|
||||
10
src/axolotl/prompters.py
Normal file
10
src/axolotl/prompters.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class AlpacaPrompter:
    """Placeholder for the Alpaca prompt-template builder.

    AlpacaPromptTokenizingStrategy (prompt_tokenizers.py) expects this to
    provide ``generate_prompt(instruction, input[, output])``.
    """

    pass
|
||||
|
||||
|
||||
class ShareGPTPrompter:
    """Placeholder for the ShareGPT prompt-template builder.

    # NOTE(review): no caller in this commit shows its expected interface;
    # the other prompters expose ``generate_prompt`` — presumably this one
    # will too.
    """

    pass
|
||||
|
||||
|
||||
class GPTeacherPrompter:
    """Placeholder for the GPTeacher prompt-template builder.

    GPTeacherPromptTokenizingStrategy (prompt_tokenizers.py) expects this
    to provide ``generate_prompt(instruction, input[, response])``.
    """

    pass
|
||||
Reference in New Issue
Block a user