cfg.cfg fix, also de-dupe lora module list

This commit is contained in:
Wing Lian
2023-05-25 09:18:57 -04:00
parent a617f1b65e
commit 676d7da661

View File

@@ -363,13 +363,13 @@ def load_lora(model, cfg):
) )
bits = None bits = None
if cfg.cfg.load_in_4bits: if cfg.load_in_4bits:
bits = 4 bits = 4
elif cfg.cfg.load_in_8bits: elif cfg.load_in_8bits:
bits = 8 bits = 8
linear_names = find_all_linear_names(bits, model) linear_names = find_all_linear_names(bits, model)
logging.info(f"found linear modules: {repr(linear_names)}") logging.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(cfg.lora_target_modules) + linear_names lora_target_modules = list(set(list(cfg.lora_target_modules) + linear_names))
lora_config = LoraConfig( lora_config = LoraConfig(
r=cfg.lora_r, r=cfg.lora_r,