fixes to make qlora actually work
This commit is contained in:
@@ -248,7 +248,7 @@ def load_model(
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
||||||
) and not cfg.load_4bit:
|
) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
|
||||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_int8_training(model)
|
||||||
|
|
||||||
@@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter):
|
|||||||
|
|
||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if adapter == "lora" or adapter == "qlora":
|
if adapter in ["lora" , "qlora"]:
|
||||||
return load_lora(model, cfg)
|
return load_lora(model, cfg)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
return load_llama_adapter(model, cfg)
|
return load_llama_adapter(model, cfg)
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
)
|
)
|
||||||
callbacks.append(early_stop_cb)
|
callbacks.append(early_stop_cb)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
|
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
|
||||||
callbacks.append(SavePeftModelCallback)
|
callbacks.append(SavePeftModelCallback)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
|
|||||||
Reference in New Issue
Block a user