add support for optimum BetterTransformer

This commit is contained in:
Wing Lian
2023-05-27 17:57:29 -04:00
parent a6f5e5eaec
commit adea682316
5 changed files with 48 additions and 22 deletions

View File

@@ -1,24 +1,25 @@
base_model: EleutherAI/gpt-neox-20b base_model: EleutherAI/gpt-neox-20b
base_model_config: EleutherAI/gpt-neox-20b
base_model_ignore_patterns: pytorch* # prefer safetensors base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
load_in_8bit: true load_in_8bit: false
load_in_4bit: true
load_4bit: false
datasets: datasets:
- path: nomic-ai/gpt4all-j-prompt-generations - path: vicgalle/alpaca-gpt4
type: alpaca type: alpaca
shards: 4
shards_index: 0
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter:
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 8 lora_r: 64
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.0
lora_target_modules: lora_target_modules:
- query_key_value lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: gpt4all-neox-20b wandb_project: gpt4all-neox-20b
wandb_watch: wandb_watch:
@@ -26,14 +27,19 @@ wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./gpt4all-neox-20b output_dir: ./gpt4all-neox-20b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 2
num_epochs: 5 num_epochs: 5
learning_rate: 0.00003 learning_rate: 0.00003
lr_scheduler: one_cycle optimizer: paged_adamw_32bit
lr_scheduler: cosine
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: True bf16: false
tf32: True fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
gradient_checkpointing: true

View File

@@ -11,6 +11,7 @@ sentencepiece
wandb wandb
einops einops
xformers xformers
optimum
# qlora things # qlora things
bert-score==0.3.13 bert-score==0.3.13
evaluate==0.4.0 evaluate==0.4.0

View File

@@ -6,6 +6,7 @@ import os
import random import random
import signal import signal
import sys import sys
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
@@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config from axolotl.utils.validation import validate_config
@@ -47,10 +50,11 @@ def choose_device(cfg):
return "cpu" return "cpu"
cfg.device = get_device() cfg.device = get_device()
if cfg.device == "cuda": if cfg.device_map != "auto":
cfg.device_map = {"": cfg.local_rank} if cfg.device == "cuda":
else: cfg.device_map = {"": cfg.local_rank}
cfg.device_map = {"": cfg.device} else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]: def get_multi_line_input() -> Optional[str]:
@@ -253,12 +257,14 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
def terminate_handler(signum, frame, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal( signal.signal(
signal.SIGINT, signal.SIGINT,
lambda signal, frame: ( lambda signum, frame: terminate_handler(signum, frame, model)
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
) )
logging.info("Starting trainer...") logging.info("Starting trainer...")
@@ -285,6 +291,8 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0: if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -11,7 +11,8 @@ import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from transformers import PreTrainedModel # noqa: F401 from transformers import PreTrainedModel # noqa: F401
from transformers import ( # noqa: F401 from optimum.bettertransformer import BetterTransformer
from transformers import (
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@@ -117,7 +118,7 @@ def load_model(
if cfg.bf16: if cfg.bf16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16: elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16 torch_dtype = torch.float16
else: else:
torch_dtype = torch.float32 torch_dtype = torch.float32
@@ -304,6 +305,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates") logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, lora_config

View File

@@ -48,6 +48,13 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
) )
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True:
logging.warning("You should probably set float16 to true")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25