attempt to find linear modules for qlora
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, TYPE_CHECKING
|
from typing import Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
|
|||||||
return model, peft_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_linear_names(bits, model):
|
||||||
|
cls = (
|
||||||
|
bnb.nn.Linear4bit
|
||||||
|
if bits == 4
|
||||||
|
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
|
||||||
|
)
|
||||||
|
lora_module_names = set()
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, cls):
|
||||||
|
names = name.split(".")
|
||||||
|
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||||
|
|
||||||
|
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||||
|
lora_module_names.remove("lm_head")
|
||||||
|
|
||||||
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg):
|
def load_lora(model, cfg):
|
||||||
# type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
|
|||||||
PeftModel,
|
PeftModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_config = None
|
bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
|
||||||
|
linear_names = find_all_linear_names(bits, model)
|
||||||
|
logging.info(f"found linear modules: {repr(linear_names)}")
|
||||||
|
lora_target_modules = cfg.lora_target_modules + linear_names
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=cfg.lora_r,
|
r=cfg.lora_r,
|
||||||
lora_alpha=cfg.lora_alpha,
|
lora_alpha=cfg.lora_alpha,
|
||||||
target_modules=cfg.lora_target_modules,
|
target_modules=lora_target_modules,
|
||||||
lora_dropout=cfg.lora_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||||
|
|||||||
Reference in New Issue
Block a user