From 34c99f9812bcd9dff4efb5bd2e8410eacf00d749 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 25 May 2023 22:37:23 -0400
Subject: [PATCH] fixes to make qlora actually work

---
 src/axolotl/utils/models.py  | 4 ++--
 src/axolotl/utils/trainer.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index de04e9333..34a02e1dd 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -248,7 +248,7 @@ def load_model(
     if (
         (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit:
+    ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter):
     if adapter is None:
         return model, None
-    if adapter == "lora" or adapter == "qlora":
+    if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index e15bbe14a..285075109 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
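
For review context, a minimal sketch (not part of the patch) of the gating logic the models.py hunk produces: the int8 preparation step now fires only when the model was actually loaded quantized. The helper name maybe_prepare_for_quantized_training is hypothetical; the cfg attributes (adapter, load_4bit, load_in_4bit) and prepare_model_for_int8_training are taken from the hunks above, which target the peft API current as of this patch.

    from peft import prepare_model_for_int8_training

    def maybe_prepare_for_quantized_training(model, cfg, load_in_8bit: bool):
        # Wrap the model for quantized training only for an 8-bit LoRA run or a
        # QLoRA run, and only when the weights are actually quantized
        # (8-bit or 4-bit); the legacy cfg.load_4bit path is excluded.
        if (
            ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
            and not cfg.load_4bit
            and (load_in_8bit or cfg.load_in_4bit)
        ):
            model = prepare_model_for_int8_training(model)
        return model

Without the added (load_in_8bit or cfg.load_in_4bit) guard, a qlora adapter on a full-precision model would be passed through int8 preparation; that mismatch is what kept qlora from working before this change.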