ensure enable_input_require_grads is called on model before getting the peft model (#345)

This commit is contained in:
Wing Lian
2023-08-06 18:13:10 -04:00
committed by GitHub
parent 3392270544
commit 176b888a63

View File

@@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter):
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg)
if adapter == "llama-adapter":