diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 063e43977..d0e5128ef 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -464,12 +464,8 @@ def load_llama_adapter(model, cfg):
     return model, peft_config
 
 
-def find_all_linear_names(bits, model):
-    cls = (
-        bnb.nn.Linear4bit
-        if bits == 4
-        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
-    )
+def find_all_linear_names(model):
+    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
     lora_module_names = set()
     for name, module in model.named_modules():
         if isinstance(module, cls):
@@ -490,13 +486,7 @@ def load_lora(model, cfg):
     lora_target_modules = list(cfg.lora_target_modules or [])
 
     if cfg.lora_target_linear:
-        bits = None
-        if cfg.load_in_4bit:
-            bits = 4
-        elif cfg.load_in_8bit:
-            bits = 8
-
-        linear_names = find_all_linear_names(bits, model)
+        linear_names = find_all_linear_names(model)
         LOG.info(f"found linear modules: {repr(linear_names)}")
         lora_target_modules = list(set(lora_target_modules + linear_names))
 