From 919623793affbfe3d139f3f0ccd7643285dd5e60 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 26 May 2023 14:32:30 +0900
Subject: [PATCH] Add cfg.lora_target_linear

---
 README.md                   |  1 +
 src/axolotl/utils/models.py | 20 ++++++++++++--------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 8a1c945d8..ae8e3e2c0 100644
--- a/README.md
+++ b/README.md
@@ -232,6 +232,7 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
+lora_target_linear: # if true, will target all linear layers
 lora_modules_to_save:
 # - embed_tokens
 # - lm_head
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 939a312d5..405c9e4b2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -364,14 +364,18 @@ def load_lora(model, cfg):
         PeftModel,
     )

-    bits = None
-    if cfg.load_in_4bit:
-        bits = 4
-    elif cfg.load_in_8bit:
-        bits = 8
-    linear_names = find_all_linear_names(bits, model)
-    logging.info(f"found linear modules: {repr(linear_names)}")
-    lora_target_modules = list(set(list(cfg.lora_target_modules) + linear_names))
+    lora_target_modules = list(cfg.lora_target_modules)
+
+    if cfg.lora_target_linear:
+        bits = None
+        if cfg.load_in_4bit:
+            bits = 4
+        elif cfg.load_in_8bit:
+            bits = 8
+
+        linear_names = find_all_linear_names(bits, model)
+        logging.info(f"found linear modules: {repr(linear_names)}")
+        lora_target_modules = list(set(lora_target_modules + linear_names))

     lora_config = LoraConfig(
         r=cfg.lora_r,