remove reference to deprecated import (#2407)
This commit is contained in:
@@ -24,7 +24,6 @@ from peft import (
|
|||||||
PeftModelForCausalLM,
|
PeftModelForCausalLM,
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from peft.tuners.lora import QuantLinear
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
@@ -1360,7 +1359,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
Reference in New Issue
Block a user