Add cfg.lora_target_linear

NanoCode012
2023-05-26 14:32:30 +09:00
parent bbfc333a01
commit 919623793a
2 changed files with 13 additions and 8 deletions


@@ -232,6 +232,7 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
+lora_target_linear: # if true, will target all linear layers
 lora_modules_to_save:
 # - embed_tokens
 # - lm_head
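
With the new key unset (the default), behavior is unchanged; setting it to true unions every linear layer found in the model into lora_target_modules. A toy sketch of the merge semantics, with illustrative module names (the real list comes from the model itself):

    configured = ["q_proj", "v_proj"]                      # from lora_target_modules
    discovered = ["q_proj", "k_proj", "v_proj", "o_proj"]  # from find_all_linear_names(...)
    merged = list(set(configured + discovered))            # deduplicated union
    print(sorted(merged))                                  # ['k_proj', 'o_proj', 'q_proj', 'v_proj']

set() deduplicates but does not preserve order, which is harmless here: PEFT matches target_modules by name, so their order does not matter.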


@@ -364,14 +364,18 @@ def load_lora(model, cfg):
         PeftModel,
     )
-    bits = None
-    if cfg.load_in_4bit:
-        bits = 4
-    elif cfg.load_in_8bit:
-        bits = 8
-    linear_names = find_all_linear_names(bits, model)
-    logging.info(f"found linear modules: {repr(linear_names)}")
-    lora_target_modules = list(set(list(cfg.lora_target_modules) + linear_names))
+
+    lora_target_modules = list(cfg.lora_target_modules)
+    if cfg.lora_target_linear:
+        bits = None
+        if cfg.load_in_4bit:
+            bits = 4
+        elif cfg.load_in_8bit:
+            bits = 8
+        linear_names = find_all_linear_names(bits, model)
+        logging.info(f"found linear modules: {repr(linear_names)}")
+        lora_target_modules = list(set(lora_target_modules + linear_names))
+
     lora_config = LoraConfig(
         r=cfg.lora_r,
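
The diff calls find_all_linear_names without showing its body. A minimal sketch of what such a helper typically looks like, patterned on the QLoRA reference code — this is an assumption, not taken from this repo, and it requires torch and bitsandbytes:

    import torch
    import bitsandbytes as bnb

    def find_all_linear_names(bits, model):
        # assumed implementation: pick the linear class matching the quantization mode
        if bits == 4:
            cls = bnb.nn.Linear4bit
        elif bits == 8:
            cls = bnb.nn.Linear8bitLt
        else:
            cls = torch.nn.Linear
        names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                # keep only the leaf name, e.g. "q_proj" out of
                # "model.layers.0.self_attn.q_proj"
                names.add(name.split(".")[-1])
        # lm_head is conventionally excluded so the output head is not adapted
        names.discard("lm_head")
        return list(names)

Whatever the actual body, the call site only relies on it returning a list of leaf module names suitable for LoraConfig's target_modules.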