diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index bc2dc84c7..0a4b3c612 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -26,7 +26,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
 
     LOG.info("Running merge of LoRA with base model...")
-    model = model.merge_and_unload(progressbar=True)
+    model = model.merge_and_unload(progressbar=True, safe_merge=True)
     try:
         model.to(dtype=cfg.torch_dtype)
     except ValueError as e:
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 03c1f35bc..1b5131e19 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -226,7 +226,7 @@ class ModelLoader:
             isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
             and not self.is_qlora_and_fsdp_enabled
         ):
-            self.model = self.model.merge_and_unload()
+            self.model = self.model.merge_and_unload(safe_merge=True)
 
         self._configure_experts_implementation()
         self._apply_activation_checkpointing()
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 16c3696c0..143509ce8 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -257,7 +257,7 @@ def save_trained_model(
     # Handle ReLoRA early return case
     if cfg.relora:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
-            model = model.merge_and_unload()
+            model = model.merge_and_unload(safe_merge=True)
        else:
            # final model weights have already been saved by `ReLoRACallback.on_train_end`
            return
diff --git a/tests/utils/lora/test_merge_lora.py b/tests/utils/lora/test_merge_lora.py
index 8edccafb9..830581e2a 100644
--- a/tests/utils/lora/test_merge_lora.py
+++ b/tests/utils/lora/test_merge_lora.py
@@ -69,7 +69,7 @@ class TestAdapterMergeUnmerge:
         self.scaling = alpha / r
 
-        def mock_merge_and_unload(progressbar=False):
+        def mock_merge_and_unload(progressbar=False, safe_merge=False):
             """Simulate the actual merge operation"""
             # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
             delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
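For reference: in PEFT, `merge_and_unload(safe_merge=True)` performs each layer's merge on a copy of the base weights and raises a `ValueError` if the merged result contains NaNs, instead of silently folding a broken adapter into the saved checkpoint. Below is a minimal sketch of the merge arithmetic the test mocks (`W_new = W_base + (B @ A) * scaling`) plus a `safe_merge`-style finiteness check. The tensor shapes and variable names are illustrative, not taken from the PR.

```python
import torch

# Illustrative shapes for a single linear layer with a LoRA adapter.
d_out, d_in, r, alpha = 64, 64, 8, 16
scaling = alpha / r  # same scaling the test computes

w_base = torch.randn(d_out, d_in)
lora_A = torch.randn(r, d_in)   # down-projection
lora_B = torch.randn(d_out, r)  # up-projection

# W_new = W_base + (B @ A) * scaling -- the delta applied by the
# mocked merge_and_unload in the test.
merged = w_base + (lora_B @ lora_A) * scaling

# safe_merge-style guard: the merge above already happened in a new
# tensor, so verify it is finite before committing it to the layer.
if not torch.isfinite(merged).all():
    raise ValueError("NaNs detected in merged weights; adapter may be broken")
w_base = merged
```

The trade-off is one temporary copy of each merged weight matrix, which is why `safe_merge` defaults to `False` in PEFT; enabling it at merge time spends a little extra memory to catch corrupted adapters before they reach disk.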