fix load_in_8bit check
This commit is contained in:
@@ -408,9 +408,9 @@ def load_model(
|
|||||||
|
|
||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
if (
|
if (
|
||||||
(cfg.adapter == "lora" and load_in_8bit)
|
(cfg.adapter == "lora" and cfg.load_in_8bit)
|
||||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||||
or (cfg.adapter == "ia3" and load_in_8bit)
|
or (cfg.adapter == "ia3" and cfg.load_in_8bit)
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
|
|||||||
Reference in New Issue
Block a user