fix load_in_8bit check

This commit is contained in:
Wing Lian
2023-09-18 18:51:56 -04:00
parent 1da328eb9a
commit c8e42a0f4f

View File

@@ -408,9 +408,9 @@ def load_model(
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if (
(cfg.adapter == "lora" and load_in_8bit)
(cfg.adapter == "lora" and cfg.load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
or (cfg.adapter == "ia3" and load_in_8bit)
or (cfg.adapter == "ia3" and cfg.load_in_8bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: