From 8d959a7e2639f3ddfc6f1fa1183d7a16acf2764a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 14 Apr 2023 07:24:55 -0400 Subject: [PATCH] make it work with pythia in the cloud --- .gitattributes | 1 + configs/pythia_1_2B_alpaca.yml | 25 ++--- scripts/finetune.py | 128 ++++++++++++++++++++++--- src/axolotl/convert.py | 1 + src/axolotl/datasets.py | 94 ++++++++++-------- src/axolotl/prompt_tokenizers.py | 15 ++- src/axolotl/prompters.py | 158 ++++++++++++++++++++++++++++++- 7 files changed, 352 insertions(+), 70 deletions(-) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..7b52c8631 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +data/*.jsonl filter=lfs diff=lfs merge=lfs -text diff --git a/configs/pythia_1_2B_alpaca.yml b/configs/pythia_1_2B_alpaca.yml index ca91f2aab..60d8b90c4 100644 --- a/configs/pythia_1_2B_alpaca.yml +++ b/configs/pythia_1_2B_alpaca.yml @@ -3,35 +3,36 @@ model_type: GPTNeoXForCausalLM tokenizer_type: AutoTokenizer load_in_8bit: true datasets: - - path: ./data/alpaca_data_gpt4.jsonl + - path: data/alpaca_data_gpt4.jsonl type: alpaca - - path: ./data/vicuna_cleaned.jsonl + - path: data/vicuna_cleaned.jsonl type: sharegpt - - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl + - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl type: gpteacher - - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl + - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl type: gpteacher val_set_size: 0.05 adapter: lora sequence_len: 2048 -lora_r: 16 +lora_r: 8 lora_alpha: 32 lora_dropout: 0.05 lora_target_modules: - - q_proj - - v_proj -wandb_project: + - query_key_value +lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific +wandb_project: pythia-1.4b-lora wandb_watch: -wandb:run_name: +wandb_run_name: wandb_log_model: checkpoint output_dir: ./lora-alpaca -batch_size: 128 -micro_batch_size: 8 +batch_size: 32 +micro_batch_size: 4 num_epochs: 5 learning_rate: 0.0003 train_on_inputs: false +group_by_length: false bf16: True -fp16: True +tf32: True resume_from_checkpoint: local_rank: deepspeed: diff --git a/scripts/finetune.py b/scripts/finetune.py index 9e5c61091..bf6c95bb4 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,26 +1,32 @@ +import math import os +import signal import sys from pathlib import Path +import bitsandbytes as bnb import fire import torch import transformers import yaml from attrdict import AttrDict -from datasets import load_dataset, IterableDataset +from datasets import load_dataset, IterableDataset, Dataset from peft import ( LoraConfig, get_peft_model, - prepare_model_for_int8_training, + prepare_model_for_int8_training, get_peft_model_state_dict, ) +from torch import nn from transformers import AutoModelForCausalLM, AutoTokenizer # add src to the pythonpath so we don't need to pip install this +from transformers.trainer_pt_utils import get_parameter_names + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) src_dir = os.path.join(project_root, 'src') sys.path.insert(0, src_dir) -from axolotl.datasets import TokenizedPromptDataset +from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \ LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter @@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg): if len(cfg.wandb_project) > 0: os.environ["WANDB_PROJECT"] = cfg.wandb_project cfg.use_wandb = True - if len(cfg.wandb_watch) > 0: + if cfg.wandb_watch and len(cfg.wandb_watch) > 0: os.environ["WANDB_WATCH"] = cfg.wandb_watch - if len(cfg.wandb_log_model) > 0: + if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model @@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"): if tokenizer.__class__.__name__ == "LlamaTokenizer": tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN + if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + if cfg.load_in_8bit: model = prepare_model_for_int8_training(model) @@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"): lora_alpha=cfg.lora_alpha, target_modules=cfg.lora_target_modules, lora_dropout=cfg.lora_dropout, + fan_in_fan_out=cfg.lora_fan_in_fan_out, bias="none", task_type="CAUSAL_LM", ) @@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"): # TODO resume_from_checkpoint handling model.print_trainable_parameters() - return model, tokenizer + return model, tokenizer, lora_config def train( @@ -88,7 +99,7 @@ def train( ): # load the config from the yaml file with open(config, 'r') as f: - cfg: AttrDict = AttrDict(yaml.load(f)) + cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader)) # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value for k, v in enumerate(kwargs): @@ -107,23 +118,116 @@ def train( setup_wandb_env_vars(cfg) # Load the model and tokenizer - model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter) + model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter) datasets = [] for d in cfg.datasets: - ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None) + ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None) if d.type == "alpaca": ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) elif d.type == "gpteacher": ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) elif d.type == "sharegpt": ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) + constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len) + constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split( + test_size=cfg.val_set_size, shuffle=True, seed=42 + ) + print(constant_len_dataset) + train_dataset = constant_len_dataset["train"] + eval_dataset = constant_len_dataset["test"] + + total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)) + warmup_steps = min(int(0.03 * total_num_steps), 100) + logging_steps = min(int(0.005 * total_num_steps), 10) + save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) + + training_args = transformers.TrainingArguments( + per_device_train_batch_size=cfg.micro_batch_size, + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + warmup_steps=warmup_steps, + num_train_epochs=cfg.num_epochs, + learning_rate=cfg.learning_rate, + bf16=cfg.bf16, + tf32=cfg.tf32, + logging_steps=logging_steps, + evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", + save_strategy="steps", + eval_steps=eval_steps if cfg.val_set_size > 0 else None, + save_steps=save_steps, + output_dir=cfg.output_dir, + save_total_limit=3, + load_best_model_at_end=True if cfg.val_set_size > 0 else False, + ddp_find_unused_parameters=False if cfg.ddp else None, + group_by_length=cfg.group_by_length, + report_to="wandb" if cfg.use_wandb else None, + run_name=cfg.wandb_run_name if cfg.use_wandb else None, + ) + + decay_parameters = get_parameter_names(model, [nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "weight_decay": training_args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if n not in decay_parameters], + "weight_decay": 0.0, + }, + ] + + adam_bnb_optim = bnb.optim.Adam8bit( + optimizer_grouped_parameters, + betas=(training_args.adam_beta1, training_args.adam_beta2), + eps=training_args.adam_epsilon, + lr=training_args.learning_rate, + ) + + lr_scheduler = transformers.get_cosine_schedule_with_warmup( + adam_bnb_optim, + training_args.warmup_steps, + total_num_steps, + ) + + trainer = transformers.Trainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + args=training_args, + optimizers=(adam_bnb_optim, lr_scheduler), + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True + ), + ) + model.config.use_cache = False + + old_state_dict = model.state_dict + model.state_dict = ( + lambda self, *_, **__: get_peft_model_state_dict( + self, old_state_dict() + ) + ).__get__(model, type(model)) + + if torch.__version__ >= "2" and sys.platform != "win32": + model = torch.compile(model) + + signal.signal(signal.SIGINT, lambda signal, frame: ( + model.save_pretrained(cfg.output_dir), + exit(0) + )) + + # go ahead and presave the adapter config + lora_config.save_pretrained(cfg.output_dir) + trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) + + model.save_pretrained(cfg.output_dir) if __name__ == "__main__": fire.Fire(train) diff --git a/src/axolotl/convert.py b/src/axolotl/convert.py index d4c2ccf2f..7a1c98d97 100644 --- a/src/axolotl/convert.py +++ b/src/axolotl/convert.py @@ -44,6 +44,7 @@ class JsonToJsonlConverter: def convert(self, input_file_path, output_file_path): content = self.file_reader.read(input_file_path) data = self.json_parser.parse(content) + # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations jsonl_content = self.jsonl_serializer.serialize(data) self.file_writer.write(jsonl_content) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index f805e92ad..0e583502c 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -2,7 +2,7 @@ from typing import List import torch from datasets import IterableDataset -from .prompt_tokenizers import PromptTokenizingStrategy +from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException # We want this to be a wrapper for an existing dataset that we have loaded @@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset): def __iter__(self): iterator = iter(self.dataset) - yield self.prompt_tokenizer.tokenize_prompt(next(iterator)) + # Loop through the entire dataset + for example in iterator: + try: + yield self.prompt_tokenizer.tokenize_prompt(example) + except InvalidDataException: + pass class ConstantLengthDataset(IterableDataset): @@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset): Args: tokenizer (Tokenizer): The processor used for proccessing the data. dataset (dataset.Dataset): Dataset with text files. - infinite (bool): If True the iterator is reset after dataset reaches end else stops. seq_length (int): Length of token sequences to return. - chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. """ def __init__( self, tokenizer, datasets, - infinite=False, seq_length=2048, - num_of_sequences=1024, - chars_per_token=3.6, ): self.tokenizer = tokenizer - self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id + self.concat_token_id = tokenizer.eos_token_id self.datasets: List[IterableDataset] = datasets self.seq_length = seq_length - self.infinite = infinite - self.current_size = 0 - self.max_buffer_size = seq_length * chars_per_token * num_of_sequences def __iter__(self): - iterator = iter(self.datasets) - more_examples = True - while more_examples: - buffer, buffer_len = [], 0 - while True: - if buffer_len >= self.max_buffer_size: - break + buffer = {"input_ids": [], "attention_mask": [], "labels": []} + buffer_len = 0 + for dataset in self.datasets: + iterator = iter(dataset) + more_examples = True + while more_examples: try: - buffer.append(next(iterator)) - buffer_len += len(buffer[-1]) + example = next(iterator) except StopIteration: - if self.infinite: - iterator = iter(self.datasets) - else: - more_examples = False - break - tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] - all_token_ids = [] - for tokenized_input in tokenized_inputs: - all_token_ids.extend(tokenized_input + [self.concat_token_id]) - for i in range(0, len(all_token_ids), self.seq_length): - input_ids = all_token_ids[i : i + self.seq_length] - if len(input_ids) == self.seq_length: - self.current_size += 1 - yield { - "input_ids": torch.LongTensor(input_ids), - "labels": torch.LongTensor(input_ids), - "attention_masks": torch.LongTensor(input_ids), - } + more_examples = False + example = None + + add_concat_token = False + if example: + example_len = len(example["input_ids"]) + add_concat_token = example["input_ids"][-1] != self.concat_token_id + else: + example_len = 0 + + if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length: + if buffer["input_ids"]: + input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length] + attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length] + labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] + yield { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + } + buffer = {"input_ids": [], "attention_mask": [], "labels": []} + buffer_len = 0 + + if example: + input_ids = example["input_ids"] + attention_mask = example["attention_mask"] + labels = example["labels"] + + if add_concat_token: + input_ids.append(self.concat_token_id) + attention_mask.append(1) + labels.append(self.concat_token_id) + + input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long) + attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long) + labels_with_concat = torch.tensor(labels, dtype=torch.long) + + buffer["input_ids"].append(input_ids_with_concat) + buffer["attention_mask"].append(attention_mask_with_concat) + buffer["labels"].append(labels_with_concat) + buffer_len += len(input_ids) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 1748597ba..589dd0e2a 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -9,6 +9,10 @@ LLAMA_DEFAULT_BOS_TOKEN = "" LLAMA_DEFAULT_UNK_TOKEN = "" +class InvalidDataException(Exception): + pass + + class PromptTokenizingStrategy(abc.ABC): def __init__( self, @@ -32,7 +36,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy): full_prompt = self._tokenize_full_prompt(prompt) tokenized_full_prompt = self._tokenize(full_prompt) if not self.train_on_inputs: - user_prompt = self.prompter.generate_prompt( + user_prompt = self.prompter.build_prompt( prompt["instruction"], prompt["input"] ) tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) @@ -43,7 +47,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt def _tokenize_full_prompt(self, prompt): - return self.prompter.generate_prompt( + return self.prompter.build_prompt( prompt["instruction"], prompt["input"], prompt["output"], @@ -71,7 +75,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy): class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy): def _tokenize_full_prompt(self, prompt): - return self.prompter.generate_prompt( + return self.prompter.build_prompt( prompt["instruction"], prompt["input"], prompt["response"], @@ -80,4 +84,7 @@ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): def tokenize_prompt(self, prompt): - pass + try: + return self.prompter.build_prompt(prompt["conversations"], self.tokenizer) + except (KeyError, AssertionError) as e: + raise InvalidDataException(str(e)) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index baa3d9e66..9f4742dd7 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,10 +1,160 @@ +import copy +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any, Union + +IGNORE_TOKEN_ID = -100 + + class AlpacaPrompter: - pass + prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n" + response_split = "### Response:" + + def build_prompt( + self, + instruction: str, + input: Union[None, str] = None, + output: Union[None, str] = None, + ) -> str: + # returns the full prompt from instruction and optional input + # if a label (=response, =output) is provided, it's also appended. + if input: + res = self.prompt_input.format( + instruction=instruction, input=input + ) + else: + res = self.prompt_no_input.format( + instruction=instruction + ) + if output: + res = f"{res}{output}" + return res + + def get_response(self, output: str) -> str: + return output.split(self.response_split)[1].strip() + + +class GPTeacherPrompter(AlpacaPrompter): + ... + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + DOLLY = auto() + + +# TODO clean this 💩 up +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + def get_prompt(self): + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + ) + + def append_message(self, role, message): + self.messages.append([role, message]) + + +conv_vicuna_v1_1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=["USER", "ASSISTANT"], + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) class ShareGPTPrompter: - pass + def build_prompt( + self, + source, + tokenizer + ): + if len(source) < 2: + # If there isn't a back and forth conversation, ignore it + # also happens on the data splitting leaving empty conversations + raise IndexError + conv = conv_vicuna_v1_1.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} -class GPTeacherPrompter: - pass + try: + # Apply prompt templates + if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + except IndexError as e: + # sometimes there is a bing or system chat + raise e + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2] + conv.append_message(role, sentence["value"]) + conversation = conv.get_prompt() + + # Tokenize conversations + tokenized_result = tokenizer( + conversation, + truncation=True, + max_length=2048, # FIXME + padding=False, + return_tensors=None, + ) + target = copy.deepcopy(tokenized_result["input_ids"]) + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + + rounds = conversation.split(conv.sep2) + cur_len = 1 + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(tokenizer(rou)["input_ids"]) + instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2 + target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len + + cur_len += round_len + target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len) + attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]] + + return dict(input_ids=tokenized_result["input_ids"], labels=target, + attention_mask=attention_mask)