fix lora target module, require explicit flash attention, fix min logging steps, don't use adam8bit for int4, hash prepared datasets, support hf hub datasets

This commit is contained in:
Wing Lian
2023-04-17 18:01:12 -04:00
parent 4131183115
commit 87e073d0de
4 changed files with 93 additions and 33 deletions

View File

@@ -21,7 +21,7 @@ lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
lora_target_modules: lora_target_modules:
- q_proj - q_proj
- w_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: llama-65b-lora wandb_project: llama-65b-lora
wandb_watch: wandb_watch:

41
configs/llama_7B_4bit.yml Normal file
View File

@@ -0,0 +1,41 @@
# LoRA fine-tuning config for LLaMA-7B loaded from a 4-bit (int4) checkpoint.
# NOTE(review): presumably base_model supplies the quantized weights and
# base_model_config the matching model config -- confirm against load_model.
base_model: decapoda-research/llama-7b-hf-int4
base_model_config: decapoda-research/llama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
# Kept on together with load_4bit (see the end of this file): the training
# script only runs its int8 model preparation and 8-bit Adam setup when
# load_in_8bit is set AND load_4bit is not.
load_in_8bit: true
datasets:
# Each entry needs both `path` and `type`; the script reads them per entry
# (they also feed the prepared-dataset hash). A local file at `path` is
# preferred; otherwise `path` is loaded as a Hugging Face Hub dataset id.
- path: vicgalle/alpaca-gpt4
  type: alpaca
# The prepared dataset is saved in a hash-named subdirectory of this path;
# the hash covers the packed sequence length and every dataset path:type pair.
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
# Length used when packing examples. The script clamps it to sequence_len and
# falls back to sequence_len when this key is left empty.
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
# Empty wandb_project: the script does not set WANDB_PROJECT in that case.
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-test
batch_size: 8
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
# A non-empty value registers an EarlyStoppingCallback with this patience.
early_stopping_patience: 3
resume_from_checkpoint:
local_rank:
# Marks the base model as 4-bit: skips prepare_model_for_int8_training and the
# custom bitsandbytes Adam8bit optimizer / cosine scheduler in the trainer.
# NOTE(review): flash_attention is not set here, so the LLaMA flash-attention
# patch (now opt-in via cfg.flash_attention) should stay disabled -- confirm.
load_4bit: true

View File

@@ -21,7 +21,7 @@ lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
lora_target_modules: lora_target_modules:
- q_proj - q_proj
- w_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: llama-7b-lora wandb_project: llama-7b-lora
wandb_watch: wandb_watch:

View File

@@ -4,6 +4,7 @@ import os
import random import random
import signal import signal
import sys import sys
from hashlib import md5
from pathlib import Path from pathlib import Path
import bitsandbytes as bnb import bitsandbytes as bnb
@@ -13,6 +14,7 @@ import transformers
import yaml import yaml
from attrdict import AttrDefault from attrdict import AttrDefault
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
from huggingface_hub.hf_api import DatasetInfo
from torch import nn from torch import nn
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -20,6 +22,7 @@ from transformers import (
LlamaForCausalLM, LlamaForCausalLM,
LlamaTokenizer, LlamaTokenizer,
EarlyStoppingCallback, EarlyStoppingCallback,
GenerationConfig,
) )
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
@@ -43,7 +46,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def setup_wandb_env_vars(cfg): def setup_wandb_env_vars(cfg):
if len(cfg.wandb_project) > 0: if cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
@@ -61,7 +64,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
if adapter != "lora": if adapter != "lora":
raise NotImplementedError(f"{adapter} peft adapter not available") raise NotImplementedError(f"{adapter} peft adapter not available")
if "llama" in base_model: if "llama" in base_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False: if cfg.device not in ["mps", "cpu"] and inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn from axolotl.flash_attn import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn() replace_llama_attn_with_flash_attn()
@@ -138,11 +141,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]: if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
if load_in_8bit: if load_in_8bit and not cfg.load_4bit:
model = prepare_model_for_int8_training(model) model = prepare_model_for_int8_training(model)
lora_config = LoraConfig( lora_config = LoraConfig(
@@ -227,14 +231,19 @@ def check_dataset_labels(dataset, tokenizer):
def do_inference(cfg, model, tokenizer): def do_inference(cfg, model, tokenizer):
tokenizer.add_special_tokens({'unk_token': '<unk>'})
tokenizer.add_special_tokens({'bos_token': '<s>'})
tokenizer.add_special_tokens({'eos_token': '</s>'})
instruction = "Tell me a joke about dromedaries." instruction = "Tell me a joke about dromedaries."
input = "" input = ""
prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input) prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
generated = model.generate(inputs=batch["input_ids"], # gc = GenerationConfig() # TODO swap out and use this
generated = model.generate(inputs=batch["input_ids"].to("cuda"),
do_sample=True, use_cache=True, do_sample=True, use_cache=True,
repetition_penalty=1.1, repetition_penalty=1.1,
max_new_tokens=100, max_new_tokens=100,
@@ -277,7 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )
warmup_steps = min(int(0.03 * total_num_steps), 100) warmup_steps = min(int(0.03 * total_num_steps), 100)
logging_steps = min(int(0.005 * total_num_steps), 10) logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
training_arguments_kwargs = {} training_arguments_kwargs = {}
@@ -325,21 +334,24 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
}, },
] ]
adam_bnb_optim = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
# TODO optionally use torch.optim.OneCycleLR
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
adam_bnb_optim,
training_args.warmup_steps,
total_num_steps,
)
trainer_kwargs = {} trainer_kwargs = {}
if cfg.load_in_8bit and not cfg.load_4bit:
adam_bnb_optim = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
# TODO optionally use torch.optim.OneCycleLR
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
adam_bnb_optim,
training_args.warmup_steps,
total_num_steps,
)
trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience, cfg.early_stopping_patience,
@@ -351,7 +363,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
args=training_args, args=training_args,
optimizers=(adam_bnb_optim, lr_scheduler),
data_collator=transformers.DataCollatorForSeq2Seq( data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
), ),
@@ -412,7 +423,11 @@ def train(
do_inference(cfg, model, tokenizer) do_inference(cfg, model, tokenizer)
return return
if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")): max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
if any(prepared_ds_path.glob("*")):
logging.info("Loading prepared dataset from disk...") logging.info("Loading prepared dataset from disk...")
dataset = load_from_disk(cfg.dataset_prepared_path) dataset = load_from_disk(cfg.dataset_prepared_path)
logging.info("Prepared dataset loaded from disk...") logging.info("Prepared dataset loaded from disk...")
@@ -420,13 +435,20 @@ def train(
logging.info("Loading raw datasets...") logging.info("Loading raw datasets...")
datasets = [] datasets = []
for d in cfg.datasets: for d in cfg.datasets:
ds_from_hub = False
try:
ds = load_dataset(d.path, streaming=True)
ds_from_hub = True
except FileNotFoundError:
pass
# prefer local dataset, even if hub exists
if Path(d.path).exists(): if Path(d.path).exists():
ds: IterableDataset = load_dataset( ds: IterableDataset = load_dataset(
"json", data_files=d.path, streaming=True, split=None "json", data_files=d.path, streaming=True, split=None
) )
# elif d.name and d.path: elif ds_from_hub:
# # TODO load from huggingface hub, but it only seems to support arrow or parquet atm ds = load_dataset(d.path, streaming=True)
# ds = load_dataset(d.path, split=None, data_files=d.name)
else: else:
raise Exception("unhandled dataset load") raise Exception("unhandled dataset load")
@@ -449,7 +471,7 @@ def train(
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
constant_len_dataset = ConstantLengthDataset( constant_len_dataset = ConstantLengthDataset(
tokenizer, datasets, seq_length=cfg.sequence_len tokenizer, datasets, seq_length=max_packed_sequence_len,
) )
logging.info("merging, packing, shuffling, and splitting master dataset") logging.info("merging, packing, shuffling, and splitting master dataset")
dataset = Dataset.from_list( dataset = Dataset.from_list(
@@ -457,11 +479,8 @@ def train(
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42) ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
if cfg.local_rank == 0: if cfg.local_rank == 0:
logging.info("Saving prepared dataset to disk...") logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
if cfg.dataset_prepared_path: dataset.save_to_disk(prepared_ds_path)
dataset.save_to_disk(cfg.dataset_prepared_path)
else:
dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH)
if prepare_ds_only: if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...") logging.info("Finished preparing dataset. Exiting...")