remove reference to deprecated import (#2407)

This commit is contained in:
Wing Lian
2025-03-15 08:49:41 -04:00
committed by GitHub
parent fbe54be6b8
commit 4f5eb42a73

View File

@@ -24,7 +24,6 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import nn
from transformers import ( # noqa: F401
AddedToken,
@@ -1360,7 +1359,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():
if (