fixes to make qlora actually work

This commit is contained in:
Wing Lian
2023-05-25 22:37:23 -04:00
parent 259262bf42
commit 34c99f9812
2 changed files with 3 additions and 3 deletions

View File

@@ -248,7 +248,7 @@ def load_model(
if (
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
) and not cfg.load_4bit:
) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
@@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter):
if adapter is None:
return model, None
if adapter == "lora" or adapter == "qlora":
if adapter in ["lora" , "qlora"]:
return load_lora(model, cfg)
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)

View File

@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
)
callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
data_collator_kwargs = {