diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c133e9eb6..bccb8b8e5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -507,7 +507,11 @@ def find_all_linear_names(model): cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) lora_module_names = set() for name, module in model.named_modules(): - if isinstance(module, cls) or "Linear" in module.__class__.__name__: + if ( + isinstance(module, cls) + or "Linear" in module.__class__.__name__ + and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) + ): names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1])