diff --git a/requirements.txt b/requirements.txt
index ae6689680..9703000eb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes>=0.39.0
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
 accelerate>=0.19.0
 sentencepiece
diff --git a/scripts/finetune.py b/scripts/finetune.py
index cb9d7e94e..b79079e26 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -14,6 +14,7 @@ from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -158,6 +159,8 @@ def train(
         cfg.fp16 = True
         cfg.bf16 = False
 
+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 3ae0a0bd4..c79c3afa7 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -11,6 +11,7 @@ class PromptStyle(Enum):
     instruct = "instruct"
     chat = "chat"
 
+
 class AlpacaPrompter:
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
@@ -50,6 +51,10 @@ class AlpacaPrompter:
         return output.split(self.response_split)[1].strip()
 
 
+class UnpromptedPrompter(AlpacaPrompter):
+    system_prompt = ""
+    system_no_input_prompt = ""
+
 class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers the clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
 
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 2ceaa4d99..8d9525fa5 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -98,6 +98,11 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
             ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
+        # support for using a subset of the data
+        if d.shards:
+            ds = ds.shuffle(seed=42)["train"].shard(
+                num_shards=d.shards, index=0
+            )
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 934f2f74c..b2955ad1a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
 
 import torch
 import transformers
+from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig,
+    AutoConfig, BitsAndBytesConfig,
 )
 
 try:
@@ -81,6 +82,16 @@ def load_model(
         logging.exception(e)
         raise e
 
+    model_kwargs = {}
+    if cfg.adapter == "qlora":
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
     try:
         if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
@@ -123,8 +134,10 @@
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
+                **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
         # This is a WIP, still an issue with the backward pass
@@ -156,9 +169,11 @@
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
@@ -169,9 +184,11 @@
                 base_model,
                 config=config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
            )
    except Exception as e:
        logging.error(
@@ -184,6 +201,7 @@
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
        )
 
    if not tokenizer:
@@ -225,7 +243,7 @@
        embeddings_len = math.ceil(len(tokenizer) / 32) * 32
        model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
        logging.info("converting PEFT model w/ prepare_model_for_int8_training")
        model = prepare_model_for_int8_training(model)
 
@@ -270,7 +288,7 @@ def load_adapter(model, cfg, adapter):
 
    if adapter is None:
        return model, None
-    if adapter == "lora":
+    if adapter == "lora" or adapter == "qlora":
        return load_lora(model, cfg)
    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)