diff --git a/configs/gpt_neox_20b.yml b/configs/gpt_neox_20b.yml
index 730afb72c..25fdae53b 100644
--- a/configs/gpt_neox_20b.yml
+++ b/configs/gpt_neox_20b.yml
@@ -1,24 +1,25 @@
 base_model: EleutherAI/gpt-neox-20b
+base_model_config: EleutherAI/gpt-neox-20b
 base_model_ignore_patterns: pytorch* # prefer safetensors
 model_type: GPTNeoXForCausalLM
 tokenizer_type: AutoTokenizer
-load_in_8bit: true
+load_in_8bit: false
+load_in_4bit: true
+load_4bit: false
 datasets:
-  - path: nomic-ai/gpt4all-j-prompt-generations
+  - path: vicgalle/alpaca-gpt4
     type: alpaca
-    shards: 4
-    shards_index: 0
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-adapter: lora
+adapter:
 lora_model_dir:
 sequence_len: 2048
 max_packed_sequence_len: 2048
-lora_r: 8
+lora_r: 64
 lora_alpha: 32
-lora_dropout: 0.05
+lora_dropout: 0.0
 lora_target_modules:
-  - query_key_value
+lora_target_linear: true
 lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project: gpt4all-neox-20b
 wandb_watch:
@@ -26,14 +27,19 @@ wandb_run_id:
 wandb_log_model:
 output_dir: ./gpt4all-neox-20b
 gradient_accumulation_steps: 1
-micro_batch_size: 4
+micro_batch_size: 2
 num_epochs: 5
 learning_rate: 0.00003
-lr_scheduler: one_cycle
+optimizer: paged_adamw_32bit
+lr_scheduler: cosine
 train_on_inputs: false
 group_by_length: false
-bf16: True
-tf32: True
+bf16: false
+fp16: false
+float16: true
+tf32: true
+flash_optimum: true
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
+gradient_checkpointing: true
diff --git a/requirements.txt b/requirements.txt
index c9123fce8..d1b2f4555 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
diff --git a/scripts/finetune.py b/scripts/finetune.py
index fa2dcf903..a5b5e7c85 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -6,6 +6,7 @@ import os
 import random
 import signal
 import sys
+from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
@@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 
 # add src to the pythonpath so we don't need to pip install this
+from optimum.bettertransformer import BetterTransformer
+
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -264,12 +267,14 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:
+        def terminate_handler(signum, frame, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
         signal.signal(
             signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
@@ -299,6 +304,8 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)
 
     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 1acaf6ab3..11b4629ec 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -11,7 +11,8 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import (  # noqa: F401
+from optimum.bettertransformer import BetterTransformer
+from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -137,7 +138,7 @@ def load_model(
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16:
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
     else:
         torch_dtype = torch.float32
@@ -342,6 +343,9 @@ def load_model(
         logging.warning("there are no parameters that require gradient updates")
     model.config.use_cache = False
 
+    if cfg.flash_optimum:
+        model = BetterTransformer.transform(model)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
 
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 04ffc4c1b..ba5feafe8 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -57,6 +57,14 @@ def validate_config(cfg):
     if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
         raise ValueError("FSDP is not supported for falcon models")
 
+    if cfg.flash_optimum is True:
+        if cfg.adapter:
+            logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
+        if cfg.fp16 or cfg.bf16:
+            raise ValueError("AMP is not supported with BetterTransformer")
+        if cfg.float16 is not True:
+            logging.warning("You should probably set float16 to true")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
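For context, the save paths above call BetterTransformer.reverse() before save_pretrained(). A minimal sketch of that transform/reverse round-trip is below; the model name and output path are illustrative assumptions, not values from the patch.

```python
# Sketch only (not part of the diff): the BetterTransformer round-trip that the
# flash_optimum path relies on. Model name and output path are placeholders.
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
model = BetterTransformer.transform(model)  # swap attention modules for fused kernels

# ... training runs on the transformed model ...

model = BetterTransformer.reverse(model)  # restore the original modules before saving
model.save_pretrained("./output")
```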