diff --git a/README.md b/README.md
index e490013a9..5e8d05490 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,18 @@
 #### You know you're going to axolotl questions
 
+## Getting Started
 
-### Converting JSON data files to JSONL
+- Download some datasets.
+
+```shell
+curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
+curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
+curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
+curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
+```
+
+- Convert the JSON data files to JSONL.
 
 ```shell
 python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
@@ -11,3 +21,13 @@ python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
 python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
 python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
 ```
+
+- Using JSONL makes it easier to subset the data if you want a smaller training set, e.g. to grab 2,000 random examples:
+
+```shell
+shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
+```
+
+- Create a new YAML config or update an existing one, e.g. [configs/pythia_1_2B_alpaca.yml](configs/pythia_1_2B_alpaca.yml)
+- Install Python dependencies: `pip3 install -r requirements.txt`
+- Train! Run `python3 scripts/finetune.py` and make sure to choose the correct YAML config file.
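A note on the conversion step above: each raw file is a single JSON array of examples, and the script emits one JSON object per line (JSONL). The actual `scripts/alpaca_json_to_jsonl.py` is not part of this diff, so the sketch below is only a minimal illustration of the idea, not the script's real implementation:

```python
# Minimal sketch of JSON-array -> JSONL conversion (illustrative only;
# the real logic lives in scripts/alpaca_json_to_jsonl.py).
import json
import sys

with open(sys.argv[1], "r") as f:
    records = json.load(f)  # the raw dataset files are JSON arrays

for record in records:
    # one compact JSON object per line is exactly the JSONL format
    print(json.dumps(record))
```

Writing to stdout is what lets the README commands pipe the output with `>` into the target `.jsonl` file.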
diff --git a/configs/cerebras_1_3B_alpaca.yml b/configs/cerebras_1_3B_alpaca.yml
new file mode 100644
index 000000000..d2f0bb3be
--- /dev/null
+++ b/configs/cerebras_1_3B_alpaca.yml
@@ -0,0 +1,38 @@
+base_model: cerebras/Cerebras-GPT-1.3B
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: true
+datasets:
+  - path: data/alpaca_data_gpt4.jsonl
+    type: alpaca
+  - path: data/vicuna_cleaned.jsonl
+    type: sharegpt
+  - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
+    type: gpteacher
+  - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
+    type: gpteacher
+val_set_size: 0.05
+adapter: lora
+sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - c_attn
+lora_fan_in_fan_out: false
+wandb_project: cerebras-1.3b-lora
+wandb_watch:
+wandb_run_name:
+wandb_log_model: checkpoint
+output_dir: ./lora-alpaca
+batch_size: 32
+micro_batch_size: 4
+num_epochs: 5
+learning_rate: 0.0003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+resume_from_checkpoint:
+local_rank:
+deepspeed:
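The `lora_*` keys in the config above correspond to the fields of a `peft.LoraConfig`. The exact construction inside `load_model` is not shown in this diff, so the following is a hedged sketch of the mapping rather than the repo's actual code:

```python
# Hypothetical mapping from the YAML's lora_* keys to a peft LoraConfig;
# load_model's real construction may differ in details.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                        # lora_r
    lora_alpha=16,              # lora_alpha
    lora_dropout=0.05,          # lora_dropout
    target_modules=["c_attn"],  # lora_target_modules
    fan_in_fan_out=False,       # lora_fan_in_fan_out
    bias="none",
    task_type="CAUSAL_LM",
)
```

`c_attn` is the fused QKV projection in GPT-2-style architectures such as Cerebras-GPT, which is why it is the LoRA target module here.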
diff --git a/requirements.txt b/requirements.txt
index 6ae668860..048936baf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,9 @@ attrdict
 fire
 PyYAML==6.0
 black
+bitsandbytes
+datasets
+accelerate
+sentencepiece
+wandb
+flash-attn
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 1a7e384c3..23425d1c0 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -9,7 +9,7 @@ import fire
 import torch
 import transformers
 import yaml
-from attrdict import AttrDict
+from attrdict import AttrDefault
 from datasets import load_dataset, IterableDataset, Dataset
 from peft import (
     LoraConfig,
@@ -50,6 +50,11 @@ def setup_wandb_env_vars(cfg):
 def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
+    if "llama" in base_model:
+        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+
+        replace_llama_attn_with_flash_attn()
+
     try:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
@@ -99,24 +104,104 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     return model, tokenizer, lora_config
 
 
+def choose_device(cfg):
+    def get_device():
+        if torch.cuda.is_available():
+            return "cuda"
+        try:
+            if torch.backends.mps.is_available():
+                return "mps"
+        except AttributeError:
+            # older torch builds have no mps backend at all
+            pass
+        return "cpu"
+
+    cfg.device = get_device()
+    if cfg.device == "cuda":
+        cfg.device_map = {"": cfg.local_rank}
+    else:
+        cfg.device_map = {"": cfg.device}
+
+
+def check_dataset_labels(dataset, tokenizer):
+    from termcolor import colored
+
+    # the dataset is already shuffled, so let's just check the first 5 elements
+    for idx in range(5):
+        # Get the input_ids, labels, and attention_mask from the dataset
+        input_ids = dataset[idx]["input_ids"]
+        labels = dataset[idx]["labels"]
+        attention_mask = dataset[idx]["attention_mask"]
+
+        # Compare the input_ids and labels element-wise, ignoring positions
+        # with the ignore label (-100) or an attention_mask equal to 0
+        colored_tokens = []
+        for input_id, label_id, mask in zip(input_ids, labels, attention_mask):
+            decoded_input_token = tokenizer.decode(input_id)
+            # Choose the color based on whether the label has the ignore value or not
+            color = (
+                "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
+            )
+            colored_token = colored(decoded_input_token, color) + colored(
+                f"({label_id}, {mask})", "white"
+            )
+            colored_tokens.append(colored_token)
+
+        print(" ".join(colored_tokens))
+        print("\n\n\n")
+
+
+def choose_config(path: Path):
+    yaml_files = list(path.glob("*.yml"))
+
+    if not yaml_files:
+        raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
+
+    print("Choose a YAML file:")
+    for idx, file in enumerate(yaml_files):
+        print(f"{idx + 1}. {file}")
+
+    chosen_file = None
+    while chosen_file is None:
+        try:
+            choice = int(input("Enter the number of your choice: "))
+            if 1 <= choice <= len(yaml_files):
+                chosen_file = yaml_files[choice - 1]
+            else:
+                print("Invalid choice. Please choose a number from the list.")
+        except ValueError:
+            print("Invalid input. Please enter a number.")
+
+    return chosen_file
+
+
 def train(
-    config: Path = Path("configs/pythia_1_2B_alpaca.yml"),
+    config: Path = Path("configs/"),
     **kwargs,
 ):
+    if config.is_dir():
+        config = choose_config(config)
+
     # load the config from the yaml file
     with open(config, "r") as f:
-        cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
+        cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))
 
     # if an option was passed on the cli and its key already exists in the
     # yaml config, overwrite the config value with the cli value
-    for k, v in enumerate(kwargs):
-        if k in cfg:
-            cfg.k = v
+    cfg_keys = dict(cfg).keys()
+    for k in kwargs:
+        if k in cfg_keys:
+            # handle booleans
+            if isinstance(cfg[k], bool):
+                cfg[k] = bool(kwargs[k])
+            else:
+                cfg[k] = kwargs[k]
 
     # setup some derived config / hyperparams
     cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
-    cfg.device_map = "auto"
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    choose_device(cfg)
     cfg.ddp = cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
@@ -163,6 +248,8 @@ def train(
     train_dataset = constant_len_dataset["train"]
     eval_dataset = constant_len_dataset["test"]
 
+    # check_dataset_labels(eval_dataset, tokenizer)
+
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
@@ -240,6 +327,7 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
+    # allow stopping early with ctrl+c while still saving the trained model
     signal.signal(
         signal.SIGINT,
        lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
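To make the derived hyperparameters in `train()` concrete with the values from `configs/cerebras_1_3B_alpaca.yml` (the dataset length below is made up purely for illustration):

```python
import math

# values from configs/cerebras_1_3B_alpaca.yml
batch_size = 32
micro_batch_size = 4
num_epochs = 5

# train() splits each optimizer step of batch_size into micro batches
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps == 8

# schedule length, using a hypothetical dataset of 50,000 examples
train_dataset_len = 50_000
total_num_steps = int(math.ceil(train_dataset_len * num_epochs / batch_size))
assert total_num_steps == 7813
```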
diff --git a/setup.cfg b/setup.cfg
index a35fb96dc..8f0ba619a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,6 +17,12 @@ install_requires =
     fire
     PyYAML == 6.0
     black
+    bitsandbytes
+    datasets
+    accelerate
+    sentencepiece
+    wandb
+    flash-attn
 
 [options.packages.find]
 where = src
diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py
new file mode 100644
index 000000000..c1ceec788
--- /dev/null
+++ b/src/axolotl/flash_attn.py
@@ -0,0 +1,116 @@
+# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
+
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from einops import rearrange
+
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel
+
+    attention_mask: [bsz, q_len]
+    """
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    # each projection: [bsz, q_len, nh, hd] -> transpose(1, 2) -> [bsz, nh, q_len, hd]
+
+    kv_seq_len = key_states.shape[-2]
+    assert past_key_value is None, "past_key_value is not supported"
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+    assert not output_attentions, "output_attentions is not supported"
+    assert not use_cache, "use_cache is not supported"
+
+    # Flash attention code from
+    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+    # transform the data into the format required by flash attention
+    qkv = torch.stack(
+        [query_states, key_states, value_states], dim=2
+    )  # [bsz, nh, 3, q_len, hd]
+    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+    # We have disabled _prepare_decoder_attention_mask in LlamaModel,
+    # so the attention_mask is the same as the key_padding_mask
+    key_padding_mask = attention_mask
+
+    if key_padding_mask is None:
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+        max_s = q_len
+        cu_q_lens = torch.arange(
+            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+        )
+        output = flash_attn_unpadded_qkvpacked_func(
+            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+    else:
+        nheads = qkv.shape[-2]
+        x = rearrange(qkv, "b s three h d -> b s (three h d)")
+        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+        x_unpad = rearrange(
+            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+        )
+        output_unpad = flash_attn_unpadded_qkvpacked_func(
+            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(
+            pad_input(
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+            ),
+            "b s (h d) -> b s h d",
+            h=nheads,
+        )
+    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel, as flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+    # [bsz, seq_len]
+    return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+        _prepare_decoder_attention_mask
+    )
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
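The trickiest part of the patched `forward` is the reshaping into flash-attn's packed QKV layout. A standalone sanity check of those `rearrange` calls, with made-up dimensions and no flash-attn kernel involved:

```python
# Standalone shape check of the qkv packing used in the patched forward
# (dimensions are arbitrary; nothing here touches the flash-attn kernel).
import torch
from einops import rearrange

bsz, nh, q_len, hd = 2, 4, 8, 16
query = torch.randn(bsz, nh, q_len, hd)
key = torch.randn(bsz, nh, q_len, hd)
value = torch.randn(bsz, nh, q_len, hd)

qkv = torch.stack([query, key, value], dim=2)  # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3)                      # [bsz, q_len, 3, nh, hd]
assert qkv.shape == (bsz, q_len, 3, nh, hd)

# the no-padding branch flattens batch and sequence into one "token" axis...
flat = rearrange(qkv, "b s ... -> (b s) ...")
assert flat.shape == (bsz * q_len, 3, nh, hd)

# ...and cu_q_lens marks where each sequence starts in the flattened layout
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_q_lens)  # tensor([ 0,  8, 16], dtype=torch.int32)
```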
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index b0cb0d8ed..ce7da8f7d 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -88,5 +88,5 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
         try:
             return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
-        except (KeyError, AssertionError) as e:
+        except (KeyError, AssertionError, IndexError) as e:
             raise InvalidDataException(str(e))
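Catching `IndexError` here matters because malformed ShareGPT rows (for example, an empty `conversations` list) can index past the end of the turn list inside the prompter. A hedged sketch of how a caller might use `InvalidDataException` to drop such rows, assuming the exception is importable from the same module; the skip loop itself is illustrative, not code from this repo:

```python
# Illustrative only: skip rows that fail tokenization instead of aborting.
from axolotl.prompt_tokenizers import InvalidDataException

def tokenize_all(strategy, rows):
    for row in rows:
        try:
            yield strategy.tokenize_prompt(row)
        except InvalidDataException:
            # row had a missing key, empty conversation, or similar; drop it
            continue
```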