feat: merge adapter in fp32
@@ -26,7 +26,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)

     LOG.info("Running merge of LoRA with base model...")
-    model = model.merge_and_unload(progressbar=True)
+    model = model.merge_and_unload(progressbar=True, safe_merge=True)
     try:
         model.to(dtype=cfg.torch_dtype)
     except ValueError as e:
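Taken together, this hunk merges the adapter with safe_merge enabled and then casts the merged model to the configured dtype. A minimal, self-contained sketch of that pattern using plain PEFT follows; the model ID, adapter path, and target dtype are placeholders, not values from this commit.

# Minimal sketch, not the project's CLI: merge a LoRA adapter with
# safe_merge enabled, keeping the merge itself in fp32.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder identifiers; substitute your own base model and adapter.
base = AutoModelForCausalLM.from_pretrained("base-model-id", torch_dtype=torch.float32)
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# safe_merge asks PEFT to validate the merged weights (e.g. reject NaNs)
# before the adapter layers are folded into the base weights and removed.
model = model.merge_and_unload(progressbar=True, safe_merge=True)

# Only cast down to the serving dtype after the fp32 merge has finished.
model = model.to(dtype=torch.bfloat16)
model.save_pretrained("path/to/merged-model")
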
@@ -226,7 +226,7 @@ class ModelLoader:
             isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
             and not self.is_qlora_and_fsdp_enabled
         ):
-            self.model = self.model.merge_and_unload()
+            self.model = self.model.merge_and_unload(safe_merge=True)

         self._configure_experts_implementation()
         self._apply_activation_checkpointing()
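For context on what the new flag does: with safe_merge, PEFT applies the LoRA delta to a copy of each base weight and checks it for non-finite values before overwriting anything, instead of merging blindly in place. The helper below is a conceptual sketch of that check, not PEFT's actual implementation.

import torch

def merge_one_layer(base_weight, lora_A, lora_B, scaling, safe_merge=True):
    # W_merged = W_base + (B @ A) * scaling
    delta = (lora_B @ lora_A) * scaling
    if safe_merge:
        merged = base_weight.clone() + delta
        if not torch.isfinite(merged).all():
            raise ValueError("Non-finite values in merged weights; the adapter looks broken")
        return merged
    return base_weight + delta
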
@@ -257,7 +257,7 @@ def save_trained_model(
     # Handle ReLoRA early return case
     if cfg.relora:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
-            model = model.merge_and_unload()
+            model = model.merge_and_unload(safe_merge=True)
         else:
             # final model weights have already been saved by `ReLoRACallback.on_train_end`
             return
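The same flag is threaded through the ReLoRA path. Restated as a standalone helper (a sketch only; `cfg` here is a generic stand-in for the project's config object): the merge only happens for a plain LoRA adapter that was not loaded in 4-bit or 8-bit; otherwise the weights already saved by the ReLoRA callback are left untouched.

def maybe_merge_relora(model, cfg):
    # Sketch of the branch above; quantized (4-/8-bit) loads skip the merge
    # because the final weights were already written by the ReLoRA callback.
    if cfg.relora:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            return model.merge_and_unload(safe_merge=True)
        return None
    return model
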
@@ -69,7 +69,7 @@ class TestAdapterMergeUnmerge:

         self.scaling = alpha / r

-        def mock_merge_and_unload(progressbar=False):
+        def mock_merge_and_unload(progressbar=False, safe_merge=False):
            """Simulate the actual merge operation"""
            # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
            delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
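The mock encodes the actual merge arithmetic, W_new = W_base + (B @ A) * (alpha / r). A small self-contained check of that identity follows; shapes and values are illustrative, not taken from the test.

import torch

torch.manual_seed(0)
d_out, d_in, r, alpha = 8, 8, 4, 16
scaling = alpha / r

w_base = torch.randn(d_out, d_in)
lora_A = torch.randn(r, d_in)   # down-projection
lora_B = torch.randn(d_out, r)  # up-projection

w_merged = w_base + (lora_B @ lora_A) * scaling

# Applying the merged weight matches the base output plus the scaled LoRA path.
x = torch.randn(d_in)
expected = w_base @ x + scaling * (lora_B @ (lora_A @ x))
assert torch.allclose(w_merged @ x, expected, atol=1e-5)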