Fix future deprecate prepare_model_for_int8_training

This commit is contained in:
NanoCode012
2023-06-02 12:38:57 +09:00
parent 6abfd87d44
commit df9528f865

View File

@@ -128,7 +128,8 @@ def load_model(
) )
replace_peft_model_with_int4_lora_model() replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training else:
from peft import prepare_model_for_kbit_training
except Exception as err: except Exception as err:
logging.exception(err) logging.exception(err)
raise err raise err
@@ -269,8 +270,8 @@ def load_model(
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
): ):
logging.info("converting PEFT model w/ prepare_model_for_int8_training") logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_int8_training(model) model = prepare_model_for_kbit_training(model)
model, lora_config = load_adapter(model, cfg, adapter) model, lora_config = load_adapter(model, cfg, adapter)