Add cfg.lora_target_linear
This commit is contained in:
@@ -232,6 +232,7 @@ lora_target_modules:
|
|||||||
# - gate_proj
|
# - gate_proj
|
||||||
# - down_proj
|
# - down_proj
|
||||||
# - up_proj
|
# - up_proj
|
||||||
|
lora_target_linear: # if true, will target all linear layers
|
||||||
lora_modules_to_save:
|
lora_modules_to_save:
|
||||||
# - embed_tokens
|
# - embed_tokens
|
||||||
# - lm_head
|
# - lm_head
|
||||||
|
|||||||
@@ -364,14 +364,18 @@ def load_lora(model, cfg):
|
|||||||
PeftModel,
|
PeftModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
bits = None
|
lora_target_modules = list(cfg.lora_target_modules)
|
||||||
if cfg.load_in_4bit:
|
|
||||||
bits = 4
|
if cfg.lora_target_linear:
|
||||||
elif cfg.load_in_8bit:
|
bits = None
|
||||||
bits = 8
|
if cfg.load_in_4bit:
|
||||||
linear_names = find_all_linear_names(bits, model)
|
bits = 4
|
||||||
logging.info(f"found linear modules: {repr(linear_names)}")
|
elif cfg.load_in_8bit:
|
||||||
lora_target_modules = list(set(list(cfg.lora_target_modules) + linear_names))
|
bits = 8
|
||||||
|
|
||||||
|
linear_names = find_all_linear_names(bits, model)
|
||||||
|
logging.info(f"found linear modules: {repr(linear_names)}")
|
||||||
|
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=cfg.lora_r,
|
r=cfg.lora_r,
|
||||||
|
|||||||
Reference in New Issue
Block a user