ensure enable_input_require_grads is called on model before getting the peft model (#345)
This commit is contained in:
@@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter):
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg)
|
||||
if adapter == "llama-adapter":
|
||||
|
||||
Reference in New Issue
Block a user