Add cfg.lora_target_linear
This commit is contained in:
@@ -364,14 +364,18 @@ def load_lora(model, cfg):
|
||||
PeftModel,
|
||||
)
|
||||
|
||||
bits = None
|
||||
if cfg.load_in_4bit:
|
||||
bits = 4
|
||||
elif cfg.load_in_8bit:
|
||||
bits = 8
|
||||
linear_names = find_all_linear_names(bits, model)
|
||||
logging.info(f"found linear modules: {repr(linear_names)}")
|
||||
lora_target_modules = list(set(list(cfg.lora_target_modules) + linear_names))
|
||||
lora_target_modules = list(cfg.lora_target_modules)
|
||||
|
||||
if cfg.lora_target_linear:
|
||||
bits = None
|
||||
if cfg.load_in_4bit:
|
||||
bits = 4
|
||||
elif cfg.load_in_8bit:
|
||||
bits = 8
|
||||
|
||||
linear_names = find_all_linear_names(bits, model)
|
||||
logging.info(f"found linear modules: {repr(linear_names)}")
|
||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
|
||||
Reference in New Issue
Block a user