diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 253bdcbd8..7501878ba 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter): if adapter is None: return model, None + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() if adapter in ["lora", "qlora"]: return load_lora(model, cfg) if adapter == "llama-adapter":