diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fb363952c..b79f116fa 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -305,7 +305,9 @@ def load_model( or (cfg.adapter == "qlora" and cfg.load_in_4bit) ): logging.info("converting PEFT model w/ prepare_model_for_kbit_training") - model = prepare_model_for_kbit_training(model) + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=cfg.gradient_checkpointing + ) model, lora_config = load_adapter(model, cfg, adapter)