attempt to find linear modules for qlora
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from torch import nn
|
||||
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
|
||||
return model, peft_config
|
||||
|
||||
|
||||
def find_all_linear_names(bits, model):
|
||||
cls = (
|
||||
bnb.nn.Linear4bit
|
||||
if bits == 4
|
||||
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
|
||||
)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, cls):
|
||||
names = name.split(".")
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def load_lora(model, cfg):
|
||||
# type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
|
||||
PeftModel,
|
||||
)
|
||||
|
||||
lora_config = None
|
||||
bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
|
||||
linear_names = find_all_linear_names(bits, model)
|
||||
logging.info(f"found linear modules: {repr(linear_names)}")
|
||||
lora_target_modules = cfg.lora_target_modules + linear_names
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
target_modules=cfg.lora_target_modules,
|
||||
target_modules=lora_target_modules,
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
|
||||
Reference in New Issue
Block a user