fixes to make qlora actually work
This commit is contained in:
@@ -248,7 +248,7 @@ def load_model(
     if (
         (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit:
+    ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)

||||
@@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter):

     if adapter is None:
         return model, None
-    if adapter == "lora" or adapter == "qlora":
+    if adapter in ["lora" , "qlora"]:
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)

||||
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)

-    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)

     data_collator_kwargs = {
||||
Reference in New Issue
Block a user