diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 58e0e97ec..b778f17ac 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -128,7 +128,8 @@ def load_model( ) replace_peft_model_with_int4_lora_model() - from peft import prepare_model_for_int8_training + else: + from peft import prepare_model_for_kbit_training except Exception as err: logging.exception(err) raise err @@ -269,8 +270,8 @@ def load_model( (cfg.adapter == "lora" and load_in_8bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit) ): - logging.info("converting PEFT model w/ prepare_model_for_int8_training") - model = prepare_model_for_int8_training(model) + logging.info("converting PEFT model w/ prepare_model_for_kbit_training") + model = prepare_model_for_kbit_training(model) model, lora_config = load_adapter(model, cfg, adapter)