Merge pull request #143 from NanoCode012/fix/deprecate-prepare-8bit-training

Fix upcoming deprecation of prepare_model_for_int8_training
This commit is contained in:
NanoCode012
2023-06-08 23:07:53 +09:00
committed by GitHub

View File

@@ -128,7 +128,8 @@ def load_model(
)
replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training
else:
from peft import prepare_model_for_kbit_training
except Exception as err:
logging.exception(err)
raise err
@@ -269,8 +270,8 @@ def load_model(
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_kbit_training(model)
model, lora_config = load_adapter(model, cfg, adapter)