diff --git a/configs/llama_65B_alpaca.yml b/configs/llama_65B_alpaca.yml index dcd64173d..917faa97a 100644 --- a/configs/llama_65B_alpaca.yml +++ b/configs/llama_65B_alpaca.yml @@ -21,7 +21,7 @@ lora_alpha: 16 lora_dropout: 0.05 lora_target_modules: - q_proj - - w_proj + - v_proj lora_fan_in_fan_out: false wandb_project: llama-65b-lora wandb_watch: diff --git a/configs/llama_7B_4bit.yml b/configs/llama_7B_4bit.yml new file mode 100644 index 000000000..422ad5724 --- /dev/null +++ b/configs/llama_7B_4bit.yml @@ -0,0 +1,41 @@ +base_model: decapoda-research/llama-7b-hf-int4 +base_model_config: decapoda-research/llama-7b-hf +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +load_in_8bit: true +datasets: + - path: vicgalle/alpaca-gpt4 + type: alpaca +dataset_prepared_path: data/last_run_prepared +val_set_size: 0.04 +adapter: lora +lora_model_dir: +sequence_len: 2048 +max_packed_sequence_len: 1024 +lora_r: 8 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - v_proj +# - k_proj +# - o_proj +lora_fan_in_fan_out: false +wandb_project: +wandb_watch: +wandb_run_id: +wandb_log_model: checkpoint +output_dir: ./lora-test +batch_size: 8 +micro_batch_size: 2 +num_epochs: 3 +learning_rate: 0.00003 +train_on_inputs: false +group_by_length: false +bf16: true +tf32: true +gradient_checkpointing: false +early_stopping_patience: 3 +resume_from_checkpoint: +local_rank: +load_4bit: true diff --git a/configs/llama_7B_alpaca.yml b/configs/llama_7B_alpaca.yml index e884c4d38..20efd58d3 100644 --- a/configs/llama_7B_alpaca.yml +++ b/configs/llama_7B_alpaca.yml @@ -21,7 +21,7 @@ lora_alpha: 16 lora_dropout: 0.05 lora_target_modules: - q_proj - - w_proj + - v_proj lora_fan_in_fan_out: false wandb_project: llama-7b-lora wandb_watch: diff --git a/scripts/finetune.py b/scripts/finetune.py index 6f0313061..8c3d3457f 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -4,6 +4,7 @@ import os import random import signal import sys +from hashlib import md5 from pathlib import Path import bitsandbytes as bnb @@ -13,6 +14,7 @@ import transformers import yaml from attrdict import AttrDefault from datasets import load_dataset, IterableDataset, Dataset, load_from_disk +from huggingface_hub.hf_api import DatasetInfo from torch import nn from transformers import ( AutoModelForCausalLM, @@ -20,6 +22,7 @@ from transformers import ( LlamaForCausalLM, LlamaTokenizer, EarlyStoppingCallback, + GenerationConfig, ) # add src to the pythonpath so we don't need to pip install this @@ -43,7 +46,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" def setup_wandb_env_vars(cfg): - if len(cfg.wandb_project) > 0: + if cfg.wandb_project and len(cfg.wandb_project) > 0: os.environ["WANDB_PROJECT"] = cfg.wandb_project cfg.use_wandb = True if cfg.wandb_watch and len(cfg.wandb_watch) > 0: @@ -61,7 +64,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a if adapter != "lora": raise NotImplementedError(f"{adapter} peft adapter not available") - if "llama" in base_model: + if "llama" in base_model and cfg.flash_attention: if cfg.device not in ["mps", "cpu"] and inference is False: from axolotl.flash_attn import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() @@ -138,11 +141,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]: tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN + if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": tokenizer.add_special_tokens({"pad_token": "[PAD]"}) os.environ["TOKENIZERS_PARALLELISM"] = "false" - if load_in_8bit: + if load_in_8bit and not cfg.load_4bit: model = prepare_model_for_int8_training(model) lora_config = LoraConfig( @@ -227,14 +231,19 @@ def check_dataset_labels(dataset, tokenizer): def do_inference(cfg, model, tokenizer): + tokenizer.add_special_tokens({'unk_token': ''}) + tokenizer.add_special_tokens({'bos_token': ''}) + tokenizer.add_special_tokens({'eos_token': ''}) + instruction = "Tell me a joke about dromedaries." input = "" prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input) - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) model.eval() with torch.no_grad(): - generated = model.generate(inputs=batch["input_ids"], + # gc = GenerationConfig() # TODO swap out and use this + generated = model.generate(inputs=batch["input_ids"].to("cuda"), do_sample=True, use_cache=True, repetition_penalty=1.1, max_new_tokens=100, @@ -277,7 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) warmup_steps = min(int(0.03 * total_num_steps), 100) - logging_steps = min(int(0.005 * total_num_steps), 10) + logging_steps = max(min(int(0.005 * total_num_steps), 10), 1) save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) training_arguments_kwargs = {} @@ -325,21 +334,24 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): }, ] - adam_bnb_optim = bnb.optim.Adam8bit( - optimizer_grouped_parameters, - betas=(training_args.adam_beta1, training_args.adam_beta2), - eps=training_args.adam_epsilon, - lr=training_args.learning_rate, - ) - - # TODO optionally use torch.optim.OneCycleLR - lr_scheduler = transformers.get_cosine_schedule_with_warmup( - adam_bnb_optim, - training_args.warmup_steps, - total_num_steps, - ) - trainer_kwargs = {} + + if cfg.load_in_8bit and not cfg.load_4bit: + adam_bnb_optim = bnb.optim.Adam8bit( + optimizer_grouped_parameters, + betas=(training_args.adam_beta1, training_args.adam_beta2), + eps=training_args.adam_epsilon, + lr=training_args.learning_rate, + ) + + # TODO optionally use torch.optim.OneCycleLR + lr_scheduler = transformers.get_cosine_schedule_with_warmup( + adam_bnb_optim, + training_args.warmup_steps, + total_num_steps, + ) + trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler) + if cfg.early_stopping_patience: early_stop_cb = EarlyStoppingCallback( cfg.early_stopping_patience, @@ -351,7 +363,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): train_dataset=train_dataset, eval_dataset=eval_dataset, args=training_args, - optimizers=(adam_bnb_optim, lr_scheduler), data_collator=transformers.DataCollatorForSeq2Seq( tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True ), @@ -412,7 +423,11 @@ def train( do_inference(cfg, model, tokenizer) return - if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")): + max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len + max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len + ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest()) + prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash + if any(prepared_ds_path.glob("*")): logging.info("Loading prepared dataset from disk...") dataset = load_from_disk(cfg.dataset_prepared_path) logging.info("Prepared dataset loaded from disk...") @@ -420,13 +435,20 @@ def train( logging.info("Loading raw datasets...") datasets = [] for d in cfg.datasets: + ds_from_hub = False + try: + ds = load_dataset(d.path, streaming=True) + ds_from_hub = True + except FileNotFoundError: + pass + + # prefer local dataset, even if hub exists if Path(d.path).exists(): ds: IterableDataset = load_dataset( "json", data_files=d.path, streaming=True, split=None ) - # elif d.name and d.path: - # # TODO load from huggingface hub, but it only seems to support arrow or parquet atm - # ds = load_dataset(d.path, split=None, data_files=d.name) + elif ds_from_hub: + ds = load_dataset(d.path, streaming=True) else: raise Exception("unhandled dataset load") @@ -449,7 +471,7 @@ def train( ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) constant_len_dataset = ConstantLengthDataset( - tokenizer, datasets, seq_length=cfg.sequence_len + tokenizer, datasets, seq_length=max_packed_sequence_len, ) logging.info("merging, packing, shuffling, and splitting master dataset") dataset = Dataset.from_list( @@ -457,11 +479,8 @@ def train( ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42) if cfg.local_rank == 0: - logging.info("Saving prepared dataset to disk...") - if cfg.dataset_prepared_path: - dataset.save_to_disk(cfg.dataset_prepared_path) - else: - dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH) + logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}") + dataset.save_to_disk(prepared_ds_path) if prepare_ds_only: logging.info("Finished preparing dataset. Exiting...")