WIP for axolotl trainer

This commit is contained in:
Wing Lian
2023-04-14 00:20:05 -04:00
parent e9da4b9a30
commit ce24f5e246
16 changed files with 497 additions and 1 deletion

14
.editorconfig Normal file

@@ -0,0 +1,14 @@
root = true

[*]
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[*.py]
indent_style = space
indent_size = 4

[**.yml]
indent_style = space
indent_size = 2

3
.gitignore vendored Normal file

@@ -0,0 +1,3 @@
**/axolotl.egg-info
**/__pycache__
.idea

8
README.md

@@ -1,6 +1,13 @@
# Axolotl
### You know you're going to axolotl questions
#### You know you're going to axolotl questions
### Converting JSON data files to JSONL
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
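
Each command above reads a file containing a single top-level JSON array and writes one JSON object per line. As a rough sketch of what the script does under the hood (the path here is just an example):

```python
import json

# read a top-level JSON array and emit one compact JSON object per line
with open("data/alpaca_data_gpt4.json") as f:
    records = json.load(f)
print("\n".join(json.dumps(record) for record in records))
```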

37
configs/pythia_1_2B_alpaca.yml Normal file

@@ -0,0 +1,37 @@
base_model: EleutherAI/pythia-1.4b-deduped
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
  - path: ./data/alpaca_data_gpt4.jsonl
    type: alpaca
  - path: ./data/vicuna_cleaned.jsonl
    type: sharegpt
  - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
    type: gpteacher
  - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
    type: gpteacher
val_set_size: 0.05
adapter: lora
sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
wandb_project:
wandb_watch:
wandb_run_name:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 128
micro_batch_size: 8
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
bf16: true
fp16: true
resume_from_checkpoint:
local_rank:
deepspeed:
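
A note on the batch settings in this config: `batch_size` is the effective batch size and `micro_batch_size` is the per-step, per-device size; `scripts/finetune.py` below derives gradient accumulation from the two. A quick sketch of the arithmetic, using the values above:

```python
batch_size = 128        # effective batch size from the config
micro_batch_size = 8    # per-device batch size per optimizer micro-step
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps == 16
# under distributed data parallel, finetune.py further divides this by WORLD_SIZE
```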

8
data/README.md Normal file

@@ -0,0 +1,8 @@
```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o raw/roleplay-similarity_0.6-instruct-dataset.json
```

1
data/raw/.gitignore vendored Normal file

@@ -0,0 +1 @@
**

3
pyproject.toml Normal file

@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

6
requirements.txt Normal file

@@ -0,0 +1,6 @@
git+https://github.com/huggingface/transformers.git
git+https://github.com/huggingface/peft.git
attrdict
fire
PyYAML==6.0
black

36
scripts/alpaca_json_to_jsonl.py Normal file

@@ -0,0 +1,36 @@
import os
import sys
from pathlib import Path
from typing import Optional

import fire

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

from axolotl.convert import (
    FileReader,
    FileWriter,
    JsonlSerializer,
    JsonParser,
    JsonToJsonlConverter,
    StdoutWriter,
)


def main(
    input: Path,
    output: Optional[Path] = None,
    to_stdout: Optional[bool] = False,
):
    file_reader = FileReader()
    if to_stdout or output is None:
        writer = StdoutWriter()
    else:
        writer = FileWriter(output)
    json_parser = JsonParser()
    jsonl_serializer = JsonlSerializer()

    converter = JsonToJsonlConverter(
        file_reader, writer, json_parser, jsonl_serializer
    )
    converter.convert(input, output)


if __name__ == "__main__":
    fire.Fire(main)

129
scripts/finetune.py Normal file

@@ -0,0 +1,129 @@
import os
import sys
from pathlib import Path

import fire
import torch
import transformers
import yaml
from attrdict import AttrDict
from datasets import load_dataset, IterableDataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import (
    AlpacaPromptTokenizingStrategy,
    GPTeacherPromptTokenizingStrategy,
    LLAMA_DEFAULT_PAD_TOKEN,
    ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter


def setup_wandb_env_vars(cfg):
    # empty values in the yml parse to None, so test truthiness rather than len()
    if cfg.wandb_project:
        os.environ["WANDB_PROJECT"] = cfg.wandb_project
        cfg.use_wandb = True
    if cfg.wandb_watch:
        os.environ["WANDB_WATCH"] = cfg.wandb_watch
    if cfg.wandb_log_model:
        os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model


def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
    if adapter != "lora":
        raise NotImplementedError(f"{adapter} peft adapter not available")

    try:
        model = getattr(transformers, model_type).from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )
    except Exception:
        # fall back to AutoModel when model_type isn't a transformers class name
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )

    try:
        # from_pretrained wants the model name/path, not the model object
        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(base_model)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    if tokenizer.__class__.__name__ == "LlamaTokenizer":
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

    if cfg.load_in_8bit:
        model = prepare_model_for_int8_training(model)

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    if cfg.ddp:
        model.to(f"cuda:{cfg.local_rank}")

    # TODO resume_from_checkpoint handling
    model.print_trainable_parameters()
    return model, tokenizer


def train(
    config: Path = Path("configs/pythia_1_2B_alpaca.yml"),
    **kwargs,
):
    # load the config from the yaml file
    with open(config, "r") as f:
        cfg: AttrDict = AttrDict(yaml.safe_load(f))

    # if an option was passed on the cli and it also exists in the yaml,
    # overwrite the yaml value with the cli value
    for k, v in kwargs.items():
        if k in cfg:
            cfg[k] = v

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.device_map = "auto"
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    cfg.ddp = cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.gradient_accumulation_steps = (
            cfg.gradient_accumulation_steps // cfg.world_size
        )
    setup_wandb_env_vars(cfg)

    # Load the model and tokenizer
    model, tokenizer = load_model(
        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter
    )

    datasets = []
    for d in cfg.datasets:
        # split="train" so we get an IterableDataset back rather than a dict of
        # splits; num_proc is not supported together with streaming=True
        ds: IterableDataset = load_dataset(
            "json", data_files=d.path, streaming=True, split="train"
        )
        if d.type == "alpaca":
            ds_strategy = AlpacaPromptTokenizingStrategy(
                AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
            )
            datasets.append(TokenizedPromptDataset(ds_strategy, ds))
        elif d.type == "gpteacher":
            ds_strategy = GPTeacherPromptTokenizingStrategy(
                GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
            )
            datasets.append(TokenizedPromptDataset(ds_strategy, ds))
        elif d.type == "sharegpt":
            ds_strategy = ShareGPTPromptTokenizingStrategy(
                ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
            )
            datasets.append(TokenizedPromptDataset(ds_strategy, ds))


if __name__ == "__main__":
    fire.Fire(train)
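
As the commit title says, this is WIP: `train()` stops after building the tokenized dataset wrappers. One plausible continuation, sketched here only for orientation (packing via `ConstantLengthDataset` from `src/axolotl/datasets.py`, then a stock transformers `Trainer`; none of this is in the commit):

```python
# hypothetical continuation of train(); assumes the locals built above
from transformers import Trainer, TrainingArguments

from axolotl.datasets import ConstantLengthDataset

train_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    args=TrainingArguments(
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        output_dir=cfg.output_dir,
        max_steps=1000,  # iterable datasets have no __len__, so steps, not epochs
    ),
)
trainer.train()
```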

23
setup.cfg Normal file

@@ -0,0 +1,23 @@
[metadata]
name = axolotl
version = 0.1.0
description = You know you're going to axolotl questions
author = Wing Lian
author_email = wing.lian@gmail.com
license = MIT

[options]
package_dir =
    =src
packages = find:
install_requires =
    transformers @ git+https://github.com/huggingface/transformers.git@main
    peft @ git+https://github.com/huggingface/peft.git@main
    attrdict
    fire
    PyYAML == 6.0
    black

[options.packages.find]
where = src

0
src/axolotl/__init__.py Normal file

50
src/axolotl/convert.py Normal file

@@ -0,0 +1,50 @@
import json
import sys


class FileReader:
    def read(self, file_path):
        with open(file_path, "r") as file:
            return file.read()


class FileWriter:
    def __init__(self, file_path):
        self.file_path = file_path

    def write(self, content):
        with open(self.file_path, "w") as file:
            file.write(content)


class StdoutWriter:
    def write(self, content):
        sys.stdout.write(content)
        sys.stdout.write("\n")


class JsonParser:
    def parse(self, content):
        return json.loads(content)


class JsonlSerializer:
    def serialize(self, data):
        lines = [json.dumps(item) for item in data]
        return "\n".join(lines)


class JsonToJsonlConverter:
    def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
        self.file_reader = file_reader
        self.file_writer = file_writer
        self.json_parser = json_parser
        self.jsonl_serializer = jsonl_serializer

    def convert(self, input_file_path, output_file_path):
        content = self.file_reader.read(input_file_path)
        data = self.json_parser.parse(content)
        jsonl_content = self.jsonl_serializer.serialize(data)
        # the writer passed to __init__ decides whether this goes to a file or stdout
        self.file_writer.write(jsonl_content)

86
src/axolotl/datasets.py Normal file

@@ -0,0 +1,86 @@
from typing import List

import torch
from datasets import IterableDataset

from .prompt_tokenizers import PromptTokenizingStrategy

# We want this to be a wrapper for an existing dataset that we have loaded.
# Let's use the concept of middlewares to wrap each dataset, for example:
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# Let's check to ensure we don't truncate an item in the middle; we'll use
# the collators later on to pad the datasets.


class TokenizedPromptDataset(IterableDataset):
    def __init__(
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.dataset = dataset

    def __iter__(self):
        # tokenize every example in the wrapped dataset, not just the first one
        for example in self.dataset:
            yield self.prompt_tokenizer.tokenize_prompt(example)


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        datasets (List[IterableDataset]): Datasets with text files.
        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
        seq_length (int): Length of token sequences to return.
        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
    """

    def __init__(
        self,
        tokenizer,
        datasets,
        infinite=False,
        seq_length=2048,
        num_of_sequences=1024,
        chars_per_token=3.6,
    ):
        self.tokenizer = tokenizer
        # fall back to 0 when the tokenizer defines no eos token id
        # (the previous fallback referenced an undefined `args`)
        self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else 0
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.datasets)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.datasets)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                # join examples with the concat/eos token so boundaries survive packing
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                        # the mask is all ones: every position holds a real token
                        "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
                    }
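
Tying the two wrappers together, per the middleware comment at the top of the file, a minimal usage sketch (it assumes a `tokenizer` and a streaming `alpaca_ds` as loaded in `scripts/finetune.py`; the `ShuffledDataset` from the comment does not exist yet in this commit):

```python
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter

strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, False, 2048)
tokenized = TokenizedPromptDataset(strategy, alpaca_ds)
packed = ConstantLengthDataset(tokenizer, [tokenized], seq_length=2048)
for sample in packed:
    # each sample is a dict of seq_length-long input_ids / labels / attention_mask
    break
```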

83
src/axolotl/prompt_tokenizers.py Normal file

@@ -0,0 +1,83 @@
import abc

from transformers import PreTrainedTokenizer

IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"


class PromptTokenizingStrategy(abc.ABC):
    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs: bool = False,
        sequence_len: int = 2048,
    ):
        self.prompter = prompter
        self.tokenizer: PreTrainedTokenizer = tokenizer
        self.train_on_inputs = train_on_inputs
        self.sequence_len = sequence_len

    @abc.abstractmethod
    def tokenize_prompt(self, prompt):
        pass


class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
    def tokenize_prompt(self, prompt):
        full_prompt = self._build_full_prompt(prompt)
        tokenized_full_prompt = self._tokenize(full_prompt)
        if not self.train_on_inputs:
            user_prompt = self.prompter.generate_prompt(
                prompt["instruction"], prompt["input"]
            )
            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # mask out the user prompt so loss is only computed on the response
            # TODO this could be sped up using numpy array slicing
            tokenized_full_prompt["labels"] = [
                IGNORE_INDEX
            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
        return tokenized_full_prompt

    def _build_full_prompt(self, prompt):
        # builds the prompt string (instruction + input + output);
        # tokenization happens separately in _tokenize
        return self.prompter.generate_prompt(
            prompt["instruction"],
            prompt["input"],
            prompt["output"],
        )

    def _tokenize(self, prompt, add_eos_token=True):
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        result["labels"] = result["input_ids"].copy()
        return result


class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
    def _build_full_prompt(self, prompt):
        # GPTeacher data uses "response" where alpaca uses "output"
        return self.prompter.generate_prompt(
            prompt["instruction"],
            prompt["input"],
            prompt["response"],
        )


class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    def tokenize_prompt(self, prompt):
        # TODO: multi-turn sharegpt conversations are not handled yet in this WIP
        pass

10
src/axolotl/prompters.py Normal file

@@ -0,0 +1,10 @@
class AlpacaPrompter:
    pass


class ShareGPTPrompter:
    pass


class GPTeacherPrompter:
    pass
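
All three prompters are stubs at this point. Purely as a point of reference for the `alpaca` dataset type, a `generate_prompt` along the lines of the well-known tloen/alpaca-lora template (which the data README pulls from; not part of this commit) might look like:

```python
# hypothetical AlpacaPrompter sketch following the tloen/alpaca-lora template
class AlpacaPrompter:
    def generate_prompt(self, instruction, input=None, output=None):
        if input:
            prompt = (
                "Below is an instruction that describes a task, paired with an "
                "input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{instruction}\n\n"
                f"### Input:\n{input}\n\n"
                "### Response:\n"
            )
        else:
            prompt = (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{instruction}\n\n"
                "### Response:\n"
            )
        # when the expected output is supplied (training), append it so the
        # tokenizing strategies can mask everything before it
        return prompt + (output or "")
```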