Fix future deprecate prepare_model_for_int8_training

This commit is contained in:
NanoCode012
2023-06-02 12:38:57 +09:00
parent 6abfd87d44
commit df9528f865

View File

@@ -128,7 +128,8 @@ def load_model(
)
replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training
else:
from peft import prepare_model_for_kbit_training
except Exception as err:
logging.exception(err)
raise err
@@ -269,8 +270,8 @@ def load_model(
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_kbit_training(model)
model, lora_config = load_adapter(model, cfg, adapter)