diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c47b724a1..a5510d3bd 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -408,9 +408,9 @@ def load_model( needs_fa2_dtype = cfg.adapter or cfg.fsdp if ( - (cfg.adapter == "lora" and load_in_8bit) + (cfg.adapter == "lora" and cfg.load_in_8bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit) - or (cfg.adapter == "ia3" and load_in_8bit) + or (cfg.adapter == "ia3" and cfg.load_in_8bit) ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") if cfg.gradient_checkpointing: