diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1805a749a..93d0f13c0 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -24,7 +24,6 @@ from peft import ( PeftModelForCausalLM, prepare_model_for_kbit_training, ) -from peft.tuners.lora import QuantLinear from torch import nn from transformers import ( # noqa: F401 AddedToken, @@ -1360,7 +1359,7 @@ def load_llama_adapter(model, cfg): def find_all_linear_names(model): - cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) + cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) lora_module_names = set() for name, module in model.named_modules(): if (