diff --git a/README.md b/README.md
index a0e036566..8a1c945d8 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
 
 ## Quickstart ⚡
 
-**Requirements**: Python 3.9. 
+**Requirements**: Python 3.9.
 
 ```bash
 git clone https://github.com/OpenAccess-AI-Collective/axolotl
@@ -45,7 +45,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
 
 ### Environment
 
-- Docker 
+- Docker
   ```bash
   docker run --gpus '"all"' --rm -it winglian/axolotl:main
   ```
@@ -334,7 +334,7 @@ strict:
 
 ### Accelerate
 
-Configure accelerate 
+Configure accelerate
 
 ```bash
 accelerate config
@@ -368,7 +368,7 @@ Pass the appropriate flag to the train command:
 Add below flag to train command above
 
 ```bash
---merge_lora --lora_model_dir="./completed-model" 
+--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
 ## Common Errors 🧰
@@ -389,7 +389,7 @@ Try set `fp16: true`
 Try to turn off xformers.
 
 ## Need help? 🙋‍♂️
- 
+
 Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
 
 ## Contributing 🤝
diff --git a/scripts/finetune.py b/scripts/finetune.py
index b79079e26..0c8727401 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -176,6 +176,7 @@ def train(
     if "merge_lora" in kwargs and cfg.adapter is not None:
         logging.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
+        model.to(dtype=torch.float16)
 
         if cfg.local_rank == 0:
             logging.info("saving merged model")
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 12b4f74a0..b2045c229 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,6 +1,7 @@
 import logging
 from hashlib import md5
 from pathlib import Path
+from typing import Optional, Union
 
 from datasets import (
     load_from_disk,
@@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets(
         logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
-            ds = None
+            ds: Optional[Union[Dataset, DatasetDict]] = None
            ds_from_hub = False
            try:
                load_dataset(d.path, streaming=True, use_auth_token=True)
@@ -90,36 +91,32 @@
 
             # prefer local dataset, even if hub exists
             if Path(d.path).exists():
-                ds: IterableDataset = load_dataset(
+                ds: Dataset = load_dataset(
                     "json", data_files=d.path, streaming=False, split=None
                 )
             elif ds_from_hub:
                 if d.data_files:
-                    ds = load_dataset(
+                    ds: Dataset = load_dataset(
                         d.path,
                         streaming=False,
                         data_files=d.data_files,
                         use_auth_token=True,
                     )
                 else:
-                    ds = load_dataset(d.path, streaming=False, use_auth_token=True)
+                    ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True)
             else:
                 fp = hf_hub_download(
                     repo_id=d.path, repo_type="dataset", filename=d.data_files
                 )
-                ds = load_dataset("json", data_files=fp, streaming=False, split=None)
+                ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
             if not ds:
                 raise Exception("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
-<<<<<<< Updated upstream
-                ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
-=======
                 if "train" in ds:
-                    ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
+                    ds: Dataset = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
                 else:
-                    ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
->>>>>>> Stashed changes
+                    ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
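For readers unfamiliar with the `datasets` call the sharding hunk above repairs: `shard(num_shards=N, index=0)` keeps a deterministic 1/N slice, so pairing it with a seeded `shuffle` is how `d.shards` subsamples the training data. A minimal, self-contained sketch of that behavior (the toy dataset here is illustrative, not part of the patch):

```python
from datasets import Dataset

# Toy data standing in for a loaded training set.
ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

# Deterministic shuffle, then keep the first of N shards -- the same
# shuffle-then-shard pattern the `d.shards` branch above uses.
num_shards = 4
subset = ds.shuffle(seed=42).shard(num_shards=num_shards, index=0)

print(len(subset))  # 25 -- a reproducible quarter of the data
```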
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index de04e9333..939a312d5 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -85,7 +85,7 @@ def load_model(
         raise e
 
     model_kwargs = {}
-    if cfg.adapter == "qlora":
+    if cfg.adapter == "qlora" and cfg.load_in_4bit:
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
@@ -247,8 +247,10 @@ def load_model(
         model.resize_token_embeddings(embeddings_len)
 
     if (
-        (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit:
+        ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
+        and not cfg.load_4bit
+        and (load_in_8bit or cfg.load_in_4bit)
+    ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
     if adapter is None:
         return model, None
 
-    if adapter == "lora" or adapter == "qlora":
+    if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index e15bbe14a..285075109 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 9bef37406..babf246f5 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -1,9 +1,20 @@
+import logging
+
+
 def validate_config(cfg):
     if cfg.adapter == "qlora":
-        assert cfg.load_in_8bit is False
-        assert cfg.load_4bit is False
-        assert cfg.load_in_4bit is True
-        pass
+        if cfg.merge_lora:
+            # can't merge qlora if loaded in 8bit or 4bit
+            assert cfg.load_in_8bit is False
+            assert cfg.load_4bit is False
+            assert cfg.load_in_4bit is False
+        else:
+            assert cfg.load_in_8bit is False
+            assert cfg.load_4bit is False
+            assert cfg.load_in_4bit is True
+    if not cfg.load_in_8bit and cfg.adapter == "lora":
+        logging.warning("we recommend setting `load_in_8bit: true`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
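Taken together, the merge-related changes above mean a QLoRA adapter has to be merged into a base model loaded *without* quantization (hence the new `--load_in_8bit=False --load_in_4bit=False` flags in the README), and the merged weights are cast to fp16 before saving. A rough sketch of that flow using the `peft` API (the model name and output path are illustrative, not from the patch):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base model in half precision -- merging is not supported on a
# model loaded with bitsandbytes 4-bit/8-bit quantization.
base = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b",  # illustrative base model
    torch_dtype=torch.float16,
)

# Attach the trained adapter, then fold its weights into the base model.
model = PeftModel.from_pretrained(base, "./completed-model")
model = model.merge_and_unload()

# Mirror the finetune.py change above: cast to fp16 before saving.
model.to(dtype=torch.float16)
model.save_pretrained("./merged-model")
```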