make it work with pythia in the cloud

Author: Wing Lian
Date: 2023-04-14 07:24:55 -04:00
parent ce24f5e246
commit 8d959a7e26
7 changed files with 352 additions and 70 deletions

.gitattributes (new file)

@@ -0,0 +1 @@
+data/*.jsonl filter=lfs diff=lfs merge=lfs -text


@@ -3,35 +3,36 @@ model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
-  - path: ./data/alpaca_data_gpt4.jsonl
+  - path: data/alpaca_data_gpt4.jsonl
    type: alpaca
-  - path: ./data/vicuna_cleaned.jsonl
+  - path: data/vicuna_cleaned.jsonl
    type: sharegpt
-  - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
+  - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
    type: gpteacher
-  - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
+  - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
    type: gpteacher
val_set_size: 0.05
adapter: lora
sequence_len: 2048
-lora_r: 16
+lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
-  - q_proj
-  - v_proj
-wandb_project:
+  - query_key_value
+lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
+wandb_project: pythia-1.4b-lora
wandb_watch:
-wandb:run_name:
+wandb_run_name:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
-batch_size: 128
-micro_batch_size: 8
+batch_size: 32
+micro_batch_size: 4
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
group_by_length: false
-bf16: True
+fp16: True
+tf32: True
resume_from_checkpoint:
local_rank:
deepspeed:
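
For orientation, the sketch below shows what the LoRA settings above amount to when applied to a GPTNeoX/pythia checkpoint with peft. The base model name and the derivation of gradient_accumulation_steps from batch_size and micro_batch_size are illustrative assumptions, not part of this commit.

# Sketch only: the config above mapped onto a peft LoRA setup for GPTNeoX.
# "EleutherAI/pythia-1.4b-deduped" and the accumulation arithmetic are assumed, not from this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

batch_size, micro_batch_size = 32, 4
gradient_accumulation_steps = batch_size // micro_batch_size  # 8 micro-batches per optimizer step

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1.4b-deduped", load_in_8bit=True, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
model = prepare_model_for_int8_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query_key_value"],  # GPTNeoX fuses q/k/v into a single projection
    fan_in_fan_out=True,  # flagged above as pythia/GPTNeoX specific
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()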


@@ -1,26 +1,32 @@
import math
import os
import signal
import sys
from pathlib import Path
import bitsandbytes as bnb
import fire
import torch
import transformers
import yaml
from attrdict import AttrDict
from datasets import load_dataset, IterableDataset
from datasets import load_dataset, IterableDataset, Dataset
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
prepare_model_for_int8_training, get_peft_model_state_dict,
)
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
# add src to the pythonpath so we don't need to pip install this
from transformers.trainer_pt_utils import get_parameter_names
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)
from axolotl.datasets import TokenizedPromptDataset
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
@@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg):
if len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
if len(cfg.wandb_watch) > 0:
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if len(cfg.wandb_log_model) > 0:
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
@@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
if tokenizer.__class__.__name__ == "LlamaTokenizer":
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.load_in_8bit:
model = prepare_model_for_int8_training(model)
@@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
lora_alpha=cfg.lora_alpha,
target_modules=cfg.lora_target_modules,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
bias="none",
task_type="CAUSAL_LM",
)
@@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
# TODO resume_from_checkpoint handling
model.print_trainable_parameters()
return model, tokenizer
return model, tokenizer, lora_config
def train(
@@ -88,7 +99,7 @@ def train(
):
# load the config from the yaml file
with open(config, 'r') as f:
cfg: AttrDict = AttrDict(yaml.load(f))
cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
for k, v in enumerate(kwargs):
@@ -107,23 +118,116 @@ def train(
setup_wandb_env_vars(cfg)
# Load the model and tokenizer
model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
datasets = []
for d in cfg.datasets:
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
if d.type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
test_size=cfg.val_set_size, shuffle=True, seed=42
)
print(constant_len_dataset)
train_dataset = constant_len_dataset["train"]
eval_dataset = constant_len_dataset["test"]
total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
warmup_steps = min(int(0.03 * total_num_steps), 100)
logging_steps = min(int(0.005 * total_num_steps), 10)
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
warmup_steps=warmup_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
bf16=cfg.bf16,
tf32=cfg.tf32,
logging_steps=logging_steps,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
save_steps=save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=True if cfg.val_set_size > 0 else False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_name if cfg.use_wandb else None,
)
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
"weight_decay": training_args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
"weight_decay": 0.0,
},
]
adam_bnb_optim = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
adam_bnb_optim,
training_args.warmup_steps,
total_num_steps,
)
trainer = transformers.Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
optimizers=(adam_bnb_optim, lr_scheduler),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
signal.signal(signal.SIGINT, lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
exit(0)
))
# go ahead and presave the adapter config
lora_config.save_pretrained(cfg.output_dir)
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
model.save_pretrained(cfg.output_dir)
if __name__ == "__main__":
fire.Fire(train)
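
The warmup/logging/eval-step heuristics in the new TrainingArguments block are easiest to read with concrete numbers. A worked example, assuming a hypothetical packed training split of 20,000 sequences (only batch_size=32 and num_epochs=5 come from the config; the real split size depends on the datasets and on packing):

import math

num_train_examples = 20_000  # assumed size of the packed train split, for illustration only
batch_size = 32
num_epochs = 5

total_num_steps = int(math.ceil(num_train_examples * num_epochs / batch_size))  # 3125
warmup_steps = min(int(0.03 * total_num_steps), 100)             # min(93, 100)  -> 93
logging_steps = min(int(0.005 * total_num_steps), 10)            # min(15, 10)   -> 10
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)  # min(156, 200) -> 156

The cosine schedule built with get_cosine_schedule_with_warmup then decays over the same total_num_steps after warmup.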


@@ -44,6 +44,7 @@ class JsonToJsonlConverter:
def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
jsonl_content = self.jsonl_serializer.serialize(data)
self.file_writer.write(jsonl_content)


@@ -2,7 +2,7 @@ from typing import List
import torch
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
# We want this to be a wrapper for an existing dataset that we have loaded
@@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset):
def __iter__(self):
iterator = iter(self.dataset)
yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
# Loop through the entire dataset
for example in iterator:
try:
yield self.prompt_tokenizer.tokenize_prompt(example)
except InvalidDataException:
pass
class ConstantLengthDataset(IterableDataset):
@@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset):
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
"""
def __init__(
self,
tokenizer,
datasets,
infinite=False,
seq_length=2048,
num_of_sequences=1024,
chars_per_token=3.6,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
def __iter__(self):
iterator = iter(self.datasets)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0
for dataset in self.datasets:
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
buffer.append(next(iterator))
buffer_len += len(buffer[-1])
example = next(iterator)
except StopIteration:
if self.infinite:
iterator = iter(self.datasets)
else:
more_examples = False
break
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
self.current_size += 1
yield {
"input_ids": torch.LongTensor(input_ids),
"labels": torch.LongTensor(input_ids),
"attention_masks": torch.LongTensor(input_ids),
}
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0
if example:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
labels_with_concat = torch.tensor(labels, dtype=torch.long)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer_len += len(input_ids)
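
In rough terms, the rewritten __iter__ above packs already-tokenized examples end to end, inserting the EOS concat token between them, and emits a sample whenever appending the next example would overflow seq_length; whatever remains in the buffer at the end is flushed as a final, possibly shorter, sample. A toy illustration of that packing rule, independent of the classes above (token values and seq_length are made up, and only input_ids are shown):

import torch

EOS = 0         # stand-in for tokenizer.eos_token_id
seq_length = 8  # made up; the config uses 2048

examples = [
    {"input_ids": [11, 12, 13]},
    {"input_ids": [21, 22]},
    {"input_ids": [31, 32, 33, 34]},
]

def pack(examples, seq_length, eos):
    buffer, buffer_len = [], 0
    for ex in examples:
        ids = list(ex["input_ids"])
        if ids[-1] != eos:  # same add_concat_token check as above
            ids.append(eos)
        if buffer and buffer_len + len(ids) > seq_length:
            yield torch.cat(buffer)[:seq_length]
            buffer, buffer_len = [], 0
        buffer.append(torch.tensor(ids, dtype=torch.long))
        buffer_len += len(ids)
    if buffer:  # flush the trailing partial buffer, as the class above does
        yield torch.cat(buffer)[:seq_length]

for sample in pack(examples, seq_length, EOS):
    print(sample.tolist())
# [11, 12, 13, 0, 21, 22, 0]  then  [31, 32, 33, 34, 0]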


@@ -9,6 +9,10 @@ LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
class InvalidDataException(Exception):
pass
class PromptTokenizingStrategy(abc.ABC):
def __init__(
self,
@@ -32,7 +36,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
full_prompt = self._tokenize_full_prompt(prompt)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = self.prompter.generate_prompt(
user_prompt = self.prompter.build_prompt(
prompt["instruction"], prompt["input"]
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
@@ -43,7 +47,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt
def _tokenize_full_prompt(self, prompt):
return self.prompter.generate_prompt(
return self.prompter.build_prompt(
prompt["instruction"],
prompt["input"],
prompt["output"],
@@ -71,7 +75,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
def _tokenize_full_prompt(self, prompt):
return self.prompter.generate_prompt(
return self.prompter.build_prompt(
prompt["instruction"],
prompt["input"],
prompt["response"],
@@ -80,4 +84,7 @@ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
pass
try:
return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
except (KeyError, AssertionError) as e:
raise InvalidDataException(str(e))
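
For the Alpaca/GPTeacher strategies above, the train_on_inputs=False branch tokenizes the user portion of the prompt separately so those positions can be masked out of the labels; the masking lines themselves fall outside this hunk. A sketch of the usual pattern, with illustrative names only (IGNORE_TOKEN_ID = -100 matches the sentinel defined in the prompters module below):

# Illustrative only: how user-prompt tokens are typically excluded from the loss
# when train_on_inputs is False. Not a copy of the lines hidden by this hunk.
IGNORE_TOKEN_ID = -100

def mask_user_prompt(tokenized_full_prompt, tokenized_user_prompt):
    user_len = len(tokenized_user_prompt["input_ids"])
    labels = list(tokenized_full_prompt["input_ids"])
    labels[:user_len] = [IGNORE_TOKEN_ID] * user_len  # loss is computed only on the response tokens
    return {**tokenized_full_prompt, "labels": labels}

full = {"input_ids": [5, 6, 7, 8, 9], "attention_mask": [1, 1, 1, 1, 1]}
user = {"input_ids": [5, 6, 7]}
print(mask_user_prompt(full, user)["labels"])  # [-100, -100, -100, 8, 9]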


@@ -1,10 +1,160 @@
import copy
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any, Union
IGNORE_TOKEN_ID = -100
class AlpacaPrompter:
pass
prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
response_split = "### Response:"
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None,
output: Union[None, str] = None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(
instruction=instruction, input=input
)
else:
res = self.prompt_no_input.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
return res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class GPTeacherPrompter(AlpacaPrompter):
...
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
# TODO clean this 💩 up
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
def get_prompt(self):
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
)
def append_message(self, role, message):
self.messages.append([role, message])
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
class ShareGPTPrompter:
pass
def build_prompt(
self,
source,
tokenizer
):
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError
conv = conv_vicuna_v1_1.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
class GPTeacherPrompter:
pass
try:
# Apply prompt templates
if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
except IndexError as e:
# sometimes there is a bing or system chat
raise e
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"])
conversation = conv.get_prompt()
# Tokenize conversations
tokenized_result = tokenizer(
conversation,
truncation=True,
max_length=2048, # FIXME
padding=False,
return_tensors=None,
)
target = copy.deepcopy(tokenized_result["input_ids"])
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
rounds = conversation.split(conv.sep2)
cur_len = 1
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(tokenizer(rou)["input_ids"])
instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
cur_len += round_len
target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]]
return dict(input_ids=tokenized_result["input_ids"], labels=target,
attention_mask=attention_mask)
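
As a quick sanity check of the Alpaca template added above, building a prompt and splitting it back apart round-trips the response. The instruction/input/output strings here are made up; the import path mirrors the one used in the training script:

from axolotl.prompters import AlpacaPrompter

prompter = AlpacaPrompter()
prompt = prompter.build_prompt(
    instruction="Translate to French.",
    input="Good morning",
    output="Bonjour",
)
print(prompt.endswith("### Response:\nBonjour"))  # True
print(prompter.get_response(prompt))              # Bonjour

The ShareGPT path instead renders the vicuna-style Conversation defined above and masks everything up to each ASSISTANT turn with IGNORE_TOKEN_ID before returning the example.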