Compare commits
1 commit
08fc7de87e ... fix/merge-

| Author | SHA1 | Date |
|---|---|---|
|  | dce5bed379 |  |
```diff
@@ -26,7 +26,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
 
     LOG.info("Running merge of LoRA with base model...")
-    model = model.merge_and_unload(progressbar=True)
+    model = model.merge_and_unload(progressbar=True, safe_merge=True)
     try:
         model.to(dtype=cfg.torch_dtype)
     except ValueError as e:
```
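For context on the flag being added here: with `safe_merge=True`, PEFT computes the merged weights on a copy and raises a `ValueError` if the result contains NaNs, instead of silently writing broken weights into the base model. A minimal sketch of that safeguard, assuming standard LoRA merge arithmetic; `merge_with_nan_check` and its parameters are illustrative, not PEFT internals:

```python
import torch

def merge_with_nan_check(base_weight: torch.Tensor,
                         lora_A: torch.Tensor,
                         lora_B: torch.Tensor,
                         scaling: float) -> torch.Tensor:
    # Merge on a copy first: W_new = W_base + (B @ A) * scaling
    merged = base_weight.clone() + (lora_B @ lora_A) * scaling
    # The safe_merge check: refuse to commit weights containing NaN/inf.
    if not torch.isfinite(merged).all():
        raise ValueError("NaNs detected in the merged weights")
    return merged
```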
```diff
@@ -226,7 +226,7 @@ class ModelLoader:
             isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
             and not self.is_qlora_and_fsdp_enabled
         ):
-            self.model = self.model.merge_and_unload()
+            self.model = self.model.merge_and_unload(safe_merge=True)
 
         self._configure_experts_implementation()
         self._apply_activation_checkpointing()
```
```diff
@@ -257,7 +257,7 @@ def save_trained_model(
     # Handle ReLoRA early return case
     if cfg.relora:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
-            model = model.merge_and_unload()
+            model = model.merge_and_unload(safe_merge=True)
         else:
             # final model weights have already been saved by `ReLoRACallback.on_train_end`
             return
```
```diff
@@ -69,7 +69,7 @@ class TestAdapterMergeUnmerge:
 
         self.scaling = alpha / r
 
-        def mock_merge_and_unload(progressbar=False):
+        def mock_merge_and_unload(progressbar=False, safe_merge=False):
            """Simulate the actual merge operation"""
            # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
            delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
```
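The mock mirrors the real merge arithmetic, W_new = W_base + (B @ A) * (alpha / r). A self-contained sketch of that update; the shapes and names (`W_base`, `lora_A`, `lora_B`) are illustrative:

```python
import torch

out_features, in_features, r, alpha = 16, 16, 4, 8
scaling = alpha / r  # matches self.scaling = alpha / r above

W_base = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)   # low-rank down-projection
lora_B = torch.zeros(out_features, r)  # low-rank up-projection, zero-initialized

# Apply the LoRA delta, as in the mock's delta_q computation.
delta = (lora_B @ lora_A) * scaling
W_new = W_base + delta

# With B zero-initialized (the standard LoRA init), the delta is zero and
# the merged weights equal the base weights exactly.
assert torch.equal(W_new, W_base)
```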