From 34c99f9812bcd9dff4efb5bd2e8410eacf00d749 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 22:37:23 -0400 Subject: [PATCH 1/8] fixes to make qlora actually work --- src/axolotl/utils/models.py | 4 ++-- src/axolotl/utils/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index de04e9333..34a02e1dd 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -248,7 +248,7 @@ def load_model( if ( (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora" - ) and not cfg.load_4bit: + ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit): logging.info("converting PEFT model w/ prepare_model_for_int8_training") model = prepare_model_for_int8_training(model) @@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter): if adapter is None: return model, None - if adapter == "lora" or adapter == "qlora": + if adapter in ["lora" , "qlora"]: return load_lora(model, cfg) if adapter == "llama-adapter": return load_llama_adapter(model, cfg) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index e15bbe14a..285075109 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): ) callbacks.append(early_stop_cb) - if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0 + if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0 callbacks.append(SavePeftModelCallback) data_collator_kwargs = { From 3f6017db9e88dcff1011c38e6aa37888faca4f09 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 22:39:13 -0400 Subject: [PATCH 2/8] qlora merge and load requires that base model isn't loaded in 4 or 8 bit --- README.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f79a49a1f..28cbf21b8 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ ## Quickstart ⚡ -**Requirements**: Python 3.9. +**Requirements**: Python 3.9. ```bash git clone https://github.com/OpenAccess-AI-Collective/axolotl @@ -45,7 +45,7 @@ accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml \ ### Environment -- Docker +- Docker ```bash docker run --gpus '"all"' --rm -it winglian/axolotl:main ``` @@ -332,7 +332,7 @@ seed: ### Accelerate -Configure accelerate +Configure accelerate ```bash accelerate config @@ -363,12 +363,18 @@ Pass the appropriate flag to the train command: ### Merge LORA to base -Add below flag to train command above +Add below flag to train command above (and using LoRA) ```bash --merge_lora --lora_model_dir="./completed-model" ``` +Add below flag to train command above (and using QLoRA) + +```bash +--merge_lora --lora_model_dir="./completed-model" --load_in_8bit False --load_in_4bit False +``` + ## Common Errors 🧰 > Cuda out of memory @@ -383,7 +389,7 @@ Please reduce any below Try set `fp16: true` ## Need help? 🙋‍♂️ - + Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you ## Contributing 🤝 From 7b5e762be2a77e4584fbf5d87ea420407fdd9cdc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 22:40:27 -0400 Subject: [PATCH 3/8] fix merge conflict failure, black format --- src/axolotl/utils/data.py | 4 ---- src/axolotl/utils/models.py | 8 +++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 12b4f74a0..8e333ca8b 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -112,14 +112,10 @@ def load_tokenized_prepared_datasets( raise Exception("unhandled dataset load") # support for using a subset of the data if d.shards: -<<<<<<< Updated upstream - ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) -======= if "train" in ds: ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0) else: ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0) ->>>>>>> Stashed changes d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 34a02e1dd..93c111a78 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -247,8 +247,10 @@ def load_model( model.resize_token_embeddings(embeddings_len) if ( - (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora" - ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit): + ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") + and not cfg.load_4bit + and (load_in_8bit or cfg.load_in_4bit) + ): logging.info("converting PEFT model w/ prepare_model_for_int8_training") model = prepare_model_for_int8_training(model) @@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter): if adapter is None: return model, None - if adapter in ["lora" , "qlora"]: + if adapter in ["lora", "qlora"]: return load_lora(model, cfg) if adapter == "llama-adapter": return load_llama_adapter(model, cfg) From e7e1a777bdf529ede1eb84c1c8df3830156cae34 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 22:45:41 -0400 Subject: [PATCH 4/8] fix bool args according to python fire docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 28cbf21b8..af32eb64e 100644 --- a/README.md +++ b/README.md @@ -372,7 +372,7 @@ Add below flag to train command above (and using LoRA) Add below flag to train command above (and using QLoRA) ```bash ---merge_lora --lora_model_dir="./completed-model" --load_in_8bit False --load_in_4bit False +--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False ``` ## Common Errors 🧰 From 1987e5cf569b35b78e960d47cfbae45d815ef493 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 22:55:13 -0400 Subject: [PATCH 5/8] qlora and 4bit check so we are able to merge and unload --- src/axolotl/utils/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 93c111a78..939a312d5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -85,7 +85,7 @@ def load_model( raise e model_kwargs = {} - if cfg.adapter == "qlora": + if cfg.adapter == "qlora" and cfg.load_in_4bit: model_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, llm_int8_threshold=6.0, From 48f4c0571e27fe6cddd5b1f649e378f9db32f7a6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 23:02:03 -0400 Subject: [PATCH 6/8] fix validation for qlora merge --- src/axolotl/utils/validation.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 9bef37406..d2cb572f3 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,9 +1,14 @@ def validate_config(cfg): if cfg.adapter == "qlora": - assert cfg.load_in_8bit is False - assert cfg.load_4bit is False - assert cfg.load_in_4bit is True - pass + if cfg.merge_lora: + # can't merge qlora if loaded in 8bit or 4bit + assert cfg.load_in_8bit is False + assert cfg.load_4bit is False + assert cfg.load_in_4bit is False + else: + assert cfg.load_in_8bit is False + assert cfg.load_4bit is False + assert cfg.load_in_4bit is True # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 From a4f12415a0c58f943449ad45fa5b80063950c2ae Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 23:10:11 -0400 Subject: [PATCH 7/8] update readme and add typehints --- README.md | 8 +------- src/axolotl/utils/data.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index af32eb64e..b33235170 100644 --- a/README.md +++ b/README.md @@ -363,13 +363,7 @@ Pass the appropriate flag to the train command: ### Merge LORA to base -Add below flag to train command above (and using LoRA) - -```bash ---merge_lora --lora_model_dir="./completed-model" -``` - -Add below flag to train command above (and using QLoRA) +Add below flag to train command above ```bash --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 8e333ca8b..b2045c229 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,6 +1,7 @@ import logging from hashlib import md5 from pathlib import Path +from typing import Union from datasets import ( load_from_disk, @@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets( logging.info("Loading raw datasets...") datasets = [] for d in cfg.datasets: - ds = None + ds: Union[Dataset, DatasetDict] = None ds_from_hub = False try: load_dataset(d.path, streaming=True, use_auth_token=True) @@ -90,32 +91,32 @@ def load_tokenized_prepared_datasets( # prefer local dataset, even if hub exists if Path(d.path).exists(): - ds: IterableDataset = load_dataset( + ds: Dataset = load_dataset( "json", data_files=d.path, streaming=False, split=None ) elif ds_from_hub: if d.data_files: - ds = load_dataset( + ds: Dataset = load_dataset( d.path, streaming=False, data_files=d.data_files, use_auth_token=True, ) else: - ds = load_dataset(d.path, streaming=False, use_auth_token=True) + ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds = load_dataset("json", data_files=fp, streaming=False, split=None) + ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: raise Exception("unhandled dataset load") # support for using a subset of the data if d.shards: if "train" in ds: - ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0) + ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) else: - ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0) + ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0] From a5bf83868512eee7bed8f4208be22ec4b858af87 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 26 May 2023 00:09:55 -0400 Subject: [PATCH 8/8] add logging and make sure model unloads to float16 --- scripts/finetune.py | 1 + src/axolotl/utils/validation.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/scripts/finetune.py b/scripts/finetune.py index b79079e26..0c8727401 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -176,6 +176,7 @@ def train( if "merge_lora" in kwargs and cfg.adapter is not None: logging.info("running merge of LoRA with base model") model = model.merge_and_unload() + model.to(dtype=torch.float16) if cfg.local_rank == 0: logging.info("saving merged model") diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index d2cb572f3..babf246f5 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,3 +1,6 @@ +import logging + + def validate_config(cfg): if cfg.adapter == "qlora": if cfg.merge_lora: @@ -9,6 +12,9 @@ def validate_config(cfg): assert cfg.load_in_8bit is False assert cfg.load_4bit is False assert cfg.load_in_4bit is True + if cfg.load_in_8bit and cfg.adapter == "lora": + logging.warning("we recommend setting `load_in_8bit: true`") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25