attempt to find linear modules for qlora

Author: Wing Lian
Date: 2023-05-24 23:18:08 -04:00
parent 3369c4dcf8
commit ffd1043607


@@ -4,6 +4,7 @@ import os
 from pathlib import Path
 from typing import Optional, Tuple, TYPE_CHECKING
+import bitsandbytes as bnb
 import torch
 import transformers
 from torch import nn
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
     return model, peft_config


+def find_all_linear_names(bits, model):
+    cls = (
+        bnb.nn.Linear4bit
+        if bits == 4
+        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+    )
+    lora_module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, cls):
+            names = name.split(".")
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+    if "lm_head" in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove("lm_head")
+    return list(lora_module_names)
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
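For intuition, here is a minimal usage sketch of the new helper. It is not part of the commit: the toy module tree is made up, bits=None is passed so the helper matches plain torch.nn.Linear rather than the bitsandbytes classes, and find_all_linear_names (plus its module-level bitsandbytes import) is assumed to be in scope.

    import torch.nn as nn

    # Toy stand-in for a transformer block; find_all_linear_names only looks
    # at module classes and dotted names, so any nn.Module tree works.
    toy = nn.ModuleDict(
        {"self_attn": nn.ModuleDict({"q_proj": nn.Linear(8, 8), "v_proj": nn.Linear(8, 8)})}
    )

    # bits=None selects torch.nn.Linear; bits=4 or bits=8 would select the
    # bitsandbytes Linear4bit / Linear8bitLt classes instead.
    print(sorted(find_all_linear_names(None, toy)))
    # -> ['q_proj', 'v_proj']  (leaf names only: "self_attn.q_proj" -> "q_proj")

On a real LLaMA-style checkpoint loaded in 4-bit, this typically yields the attention and MLP projections (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj), which matches the QLoRA recipe of adapting every linear layer.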
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
         PeftModel,
     )
     lora_config = None
+    bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
+    linear_names = find_all_linear_names(bits, model)
+    logging.info(f"found linear modules: {repr(linear_names)}")
+    lora_target_modules = cfg.lora_target_modules + linear_names
     lora_config = LoraConfig(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
-        target_modules=cfg.lora_target_modules,
+        target_modules=lora_target_modules,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
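The effect on the LoRA config is easiest to see with concrete values. A small illustration with made-up module lists (not from the commit):

    # Illustration with made-up values: explicitly configured targets are
    # concatenated with the discovered linear-layer names, so QLoRA adapters
    # attach to every linear projection rather than only the configured ones.
    configured = ["q_proj", "v_proj"]  # stands in for cfg.lora_target_modules
    discovered = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    lora_target_modules = configured + discovered
    print(lora_target_modules)
    # ['q_proj', 'v_proj', 'q_proj', 'v_proj', 'k_proj', 'o_proj',
    #  'gate_proj', 'up_proj', 'down_proj']

Note that plain list concatenation does not deduplicate: any name both listed in the config and discovered by the helper appears twice. Since peft uses target_modules as a membership test when deciding which layers to wrap, the duplicates should be redundant but harmless.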