From fe0b76854ec444643481da131228c8d214654f91 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 11 Jun 2023 09:20:40 -0400
Subject: [PATCH] match up gradient checkpointing when using lora w config

---
 src/axolotl/utils/models.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fb363952c..b79f116fa 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -305,7 +305,9 @@ def load_model(
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
-        model = prepare_model_for_kbit_training(model)
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=cfg.gradient_checkpointing
+        )
 
         model, lora_config = load_adapter(model, cfg, adapter)
 
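For context: PEFT's prepare_model_for_kbit_training() defaults
use_gradient_checkpointing to True, so before this patch the prepared model
could end up with gradient checkpointing enabled even when the axolotl config
turned it off. Below is a minimal sketch (not part of the patch) of the
wiring, assuming cfg is axolotl's parsed YAML config with a boolean
gradient_checkpointing field and model is an already-loaded quantized
transformers model:

    from peft import prepare_model_for_kbit_training

    # cfg.gradient_checkpointing comes from the user's YAML config; forwarding
    # it keeps PEFT's model preparation in sync with the trainer's setting
    # instead of always falling back to PEFT's default of True.
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=cfg.gradient_checkpointing
    )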