diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ad0f2df94..5b243bec4 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -363,13 +363,13 @@ def load_lora(model, cfg): ) bits = None - if cfg.cfg.load_in_4bits: + if cfg.load_in_4bit: bits = 4 - elif cfg.cfg.load_in_8bits: + elif cfg.load_in_8bit: bits = 8 linear_names = find_all_linear_names(bits, model) logging.info(f"found linear modules: {repr(linear_names)}") - lora_target_modules = list(cfg.lora_target_modules) + linear_names + lora_target_modules = list(set(list(cfg.lora_target_modules) + linear_names)) lora_config = LoraConfig( r=cfg.lora_r,