From ffd10436070c7459e87c7c136b8d7aeebf8b15ef Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 24 May 2023 23:18:08 -0400
Subject: [PATCH] attempt to find linear modules for qlora

---
 src/axolotl/utils/models.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index feec832b0..7e1d91aa1 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -4,6 +4,7 @@
 import os
 from pathlib import Path
 from typing import Optional, Tuple, TYPE_CHECKING
+import bitsandbytes as bnb
 import torch
 import transformers
 from torch import nn
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
     return model, peft_config
 
 
+def find_all_linear_names(bits, model):
+    cls = (
+        bnb.nn.Linear4bit
+        if bits == 4
+        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+    )
+    lora_module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, cls):
+            names = name.split(".")
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+    if "lm_head" in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove("lm_head")
+
+    return list(lora_module_names)
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
         PeftModel,
     )
 
-    lora_config = None
+    bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
+    linear_names = find_all_linear_names(bits, model)
+    logging.info(f"found linear modules: {repr(linear_names)}")
+    lora_target_modules = cfg.lora_target_modules + linear_names
 
     lora_config = LoraConfig(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
-        target_modules=cfg.lora_target_modules,
+        target_modules=lora_target_modules,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
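
Illustration (not part of the patch): a minimal sketch of what find_all_linear_names collects, exercising only the bits=None branch so it matches on plain torch.nn.Linear and runs without bitsandbytes or a GPU. The ToyModel and its layer names (q_proj, v_proj, lm_head) are invented here purely to show the naming behavior; real target-module names depend on the loaded model.

import torch
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.v_proj = nn.Linear(16, 16)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()
        self.lm_head = nn.Linear(16, 32)


def find_all_linear_names(bits, model):
    # Same traversal as the patch, restricted to the bits=None branch
    # (torch.nn.Linear) so the sketch has no bitsandbytes dependency.
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            # keep only the leaf attribute name, e.g. "attn.q_proj" -> "q_proj"
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)


print(sorted(find_all_linear_names(None, ToyModel())))
# -> ['q_proj', 'v_proj']  (lm_head is found but dropped by the helper)

The names returned this way are what the patch appends to cfg.lora_target_modules before building the LoraConfig, so every linear (or quantized linear) layer except lm_head becomes a LoRA target.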