From 3b4d055edd5c2c180ed6f52d0fa15f4dc0501508 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 22 May 2023 23:13:33 -0400
Subject: [PATCH 1/5] integrate qlora? maybe?

---
 requirements.txt            |  2 +-
 src/axolotl/utils/models.py | 34 ++++++++++++++++++++++++++++++++--
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index ae6689680..e45037fc0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
 accelerate>=0.19.0
 sentencepiece
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 934f2f74c..992abf3ed 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
 
 import torch
 import transformers
+from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig,
+    AutoConfig, BitsAndBytesConfig,
 )
 
 try:
@@ -81,6 +82,16 @@ def load_model(
         logging.exception(e)
         raise e
 
+    model_kwargs = {}
+    if cfg.adapter == "qlora":
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
     try:
         if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
@@ -125,6 +136,7 @@ def load_model(
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
+                **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
         # This is a WIP, still an issue with the backward pass
@@ -159,6 +171,7 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
@@ -172,6 +185,7 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
     except Exception as e:
         logging.error(
@@ -184,8 +198,24 @@ def load_model(
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
 
+    """### Post-processing on the model
+    Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
+    """
+    if cfg.adapter == "qlora":
+        for param in model.parameters():
+            param.requires_grad = False  # freeze the model - train adapters later
+            if param.ndim == 1:
+                # cast the small parameters (e.g. layernorm) to fp32 for stability
+                param.data = param.data.to(torch.float32)
+        class CastOutputToFloat(nn.Sequential):
+            def forward(self, x):
+                return super().forward(x).to(torch.float32)
+
+        model.lm_head = CastOutputToFloat(model.lm_head)
+
     if not tokenizer:
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
@@ -270,7 +300,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter == "lora":
+    if adapter == "lora" or adapter == "qlora":
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)

From b9d07aa95a7a2291907cfbeabc63f95c61570ed9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 23 May 2023 11:33:41 -0400
Subject: [PATCH 2/5] prepare does all this already for qlora?

---
 src/axolotl/utils/models.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 992abf3ed..1bcc4b0bc 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -204,17 +204,17 @@ def load_model(
     """### Post-processing on the model
     Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
     """
-    if cfg.adapter == "qlora":
-        for param in model.parameters():
-            param.requires_grad = False  # freeze the model - train adapters later
-            if param.ndim == 1:
-                # cast the small parameters (e.g. layernorm) to fp32 for stability
-                param.data = param.data.to(torch.float32)
-        class CastOutputToFloat(nn.Sequential):
-            def forward(self, x):
-                return super().forward(x).to(torch.float32)
-
-        model.lm_head = CastOutputToFloat(model.lm_head)
+    # if cfg.adapter == "qlora":
+    #     for param in model.parameters():
+    #         param.requires_grad = False  # freeze the model - train adapters later
+    #         if param.ndim == 1:
+    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
+    #             param.data = param.data.to(torch.float32)
+    #     class CastOutputToFloat(nn.Linear):
+    #         def forward(self, x):
+    #             return super().forward(x).to(torch.float32)
+    #
+    #     model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
 
     if not tokenizer:
         try:
@@ -255,7 +255,7 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 

From e8aacfbd7caafe0e3b4601fb340725c5b936bae4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 24 May 2023 14:33:18 -0400
Subject: [PATCH 3/5] more qlora support

---
 src/axolotl/prompters.py    | 5 +++++
 src/axolotl/utils/data.py   | 5 +++++
 src/axolotl/utils/models.py | 3 +++
 3 files changed, 13 insertions(+)

diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 3ae0a0bd4..c79c3afa7 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -11,6 +11,7 @@ class PromptStyle(Enum):
     instruct = "instruct"
     chat = "chat"
 
+
 class AlpacaPrompter:
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
@@ -50,6 +51,10 @@ class AlpacaPrompter:
         return output.split(self.response_split)[1].strip()
 
 
+class UnpromptedPrompter(AlpacaPrompter):
+    system_prompt = ""
+    system_no_input_prompt = ""
+
 class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
 
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 2ceaa4d99..8d9525fa5 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -98,6 +98,11 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
                 ds = load_dataset("json", data_files=fp, streaming=False, split=None)
             if not ds:
                 raise Exception("unhandled dataset load")
+            # support for using a subset of the data
+            if d.shards:
+                ds = ds.shuffle(seed=42)["train"].shard(
+                    num_shards=cfg.shards, index=0
+                )
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 1bcc4b0bc..571f1c6dd 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -134,6 +134,7 @@ def load_model(
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 **model_kwargs,
@@ -168,6 +169,7 @@ def load_model(
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
@@ -182,6 +184,7 @@ def load_model(
                 base_model,
                 config=config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,

From 7e81ca720bdb6ca441372e3ac98175b1c1c2736e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 24 May 2023 15:44:48 -0400
Subject: [PATCH 4/5] Update requirements.txt

Co-authored-by: NanoCode012
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index e45037fc0..9703000eb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
-bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git
+bitsandbytes>=0.39.0
 attrdict
 fire
 PyYAML==6.0

From 1f5d83ea729272ad131d379310c4e08df752ef5f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 24 May 2023 22:47:33 -0400
Subject: [PATCH 5/5] remove un-needed code, add validation

---
 scripts/finetune.py         |  3 +++
 src/axolotl/utils/models.py | 15 ---------------
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index cb9d7e94e..b79079e26 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -14,6 +14,7 @@ from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -158,6 +159,8 @@ def train(
         cfg.fp16 = True
         cfg.bf16 = False
 
+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 571f1c6dd..b2955ad1a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -204,21 +204,6 @@ def load_model(
             **model_kwargs,
         )
 
-    """### Post-processing on the model
-    Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
-    """
-    # if cfg.adapter == "qlora":
-    #     for param in model.parameters():
-    #         param.requires_grad = False  # freeze the model - train adapters later
-    #         if param.ndim == 1:
-    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
-    #             param.data = param.data.to(torch.float32)
-    #     class CastOutputToFloat(nn.Linear):
-    #         def forward(self, x):
-    #             return super().forward(x).to(torch.float32)
-    #
-    #     model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
-
     if not tokenizer:
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
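
For reference, the loading pattern that PATCH 1-3 wire into load_model() can be exercised standalone roughly as follows. This is a minimal sketch assuming the transformers/peft/bitsandbytes APIs the patches import; the base model id and the LoRA hyperparameters below are illustrative placeholders, not values taken from axolotl.

    import torch
    from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    base_model = "huggyllama/llama-7b"  # hypothetical base model id

    # 4-bit NF4 quantization with double quantization, mirroring the
    # BitsAndBytesConfig added to load_model() in PATCH 1/5
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # PATCH 2/5 drops the hand-rolled freezing/fp32 casting because peft's
    # prepare step already freezes the base weights and upcasts the norm layers
    model = prepare_model_for_int8_training(model)

    # attach LoRA adapters; rank and target modules here are illustrative only
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()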