add support for opimum bettertransformers

This commit is contained in:
Wing Lian
2023-05-27 17:57:29 -04:00
parent a6f5e5eaec
commit adea682316
5 changed files with 48 additions and 22 deletions

View File

@@ -1,24 +1,25 @@
base_model: EleutherAI/gpt-neox-20b
base_model_config: EleutherAI/gpt-neox-20b
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_8bit: false
load_in_4bit: true
load_4bit: false
datasets:
- path: nomic-ai/gpt4all-j-prompt-generations
- path: vicgalle/alpaca-gpt4
type: alpaca
shards: 4
shards_index: 0
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_dropout: 0.0
lora_target_modules:
- query_key_value
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: gpt4all-neox-20b
wandb_watch:
@@ -26,14 +27,19 @@ wandb_run_id:
wandb_log_model:
output_dir: ./gpt4all-neox-20b
gradient_accumulation_steps: 1
micro_batch_size: 4
micro_batch_size: 2
num_epochs: 5
learning_rate: 0.00003
lr_scheduler: one_cycle
optimizer: paged_adamw_32bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true

View File

@@ -11,6 +11,7 @@ sentencepiece
wandb
einops
xformers
optimum
# qlora things
bert-score==0.3.13
evaluate==0.4.0

View File

@@ -6,6 +6,7 @@ import os
import random
import signal
import sys
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
@@ -47,10 +50,11 @@ def choose_device(cfg):
return "cpu"
cfg.device = get_device()
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
if cfg.device_map != "auto":
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
@@ -253,12 +257,14 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(signum, frame, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT,
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
lambda signum, frame: terminate_handler(signum, frame, model)
)
logging.info("Starting trainer...")
@@ -285,6 +291,8 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -11,7 +11,8 @@ import bitsandbytes as bnb
import torch
import transformers
from transformers import PreTrainedModel # noqa: F401
from transformers import ( # noqa: F401
from optimum.bettertransformer import BetterTransformer
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -117,7 +118,7 @@ def load_model(
if cfg.bf16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16:
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
@@ -304,6 +305,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling
return model, lora_config

View File

@@ -48,6 +48,13 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True:
logging.warning("You should probably set float16 to true")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25