From 1db6ad60a71ab852528762a59c60a4ce4f12717c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 20 Feb 2025 22:56:34 -0500 Subject: [PATCH] support for passing init_lora_weights to lora_config (#2352) --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/models.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 26bfff7dc..1810413be 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -342,6 +342,7 @@ class LoraConfig(BaseModel): peft_use_dora: Optional[bool] = None peft_use_rslora: Optional[bool] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None + peft_init_lora_weights: Optional[Union[bool, str]] = None qlora_sharded_model_loading: Optional[bool] = Field( default=False, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 377f08605..c4c07dd33 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1321,6 +1321,8 @@ def load_lora(model, cfg, inference=False, config_only=False): if loftq_bits: lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) lora_config_kwargs["init_lora_weights"] = "loftq" + if cfg.peft_init_lora_weights: + lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights if cfg.peft_use_dora: lora_config_kwargs["use_dora"] = cfg.peft_use_dora LOG.info("Initializing LoRA weights using dora. This might take longer.")