Compare commits

...

1 Commit

Author SHA1 Message Date
NanoCode012
dce5bed379 feat: merge adapter in fp32 2026-03-14 00:20:59 +07:00
4 changed files with 4 additions and 4 deletions

View File

@@ -26,7 +26,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
LOG.info("Running merge of LoRA with base model...") LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True) model = model.merge_and_unload(progressbar=True, safe_merge=True)
try: try:
model.to(dtype=cfg.torch_dtype) model.to(dtype=cfg.torch_dtype)
except ValueError as e: except ValueError as e:

View File

@@ -226,7 +226,7 @@ class ModelLoader:
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM)) isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
and not self.is_qlora_and_fsdp_enabled and not self.is_qlora_and_fsdp_enabled
): ):
self.model = self.model.merge_and_unload() self.model = self.model.merge_and_unload(safe_merge=True)
self._configure_experts_implementation() self._configure_experts_implementation()
self._apply_activation_checkpointing() self._apply_activation_checkpointing()

View File

@@ -257,7 +257,7 @@ def save_trained_model(
# Handle ReLoRA early return case # Handle ReLoRA early return case
if cfg.relora: if cfg.relora:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload() model = model.merge_and_unload(safe_merge=True)
else: else:
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return return

View File

@@ -69,7 +69,7 @@ class TestAdapterMergeUnmerge:
self.scaling = alpha / r self.scaling = alpha / r
def mock_merge_and_unload(progressbar=False): def mock_merge_and_unload(progressbar=False, safe_merge=False):
"""Simulate the actual merge operation""" """Simulate the actual merge operation"""
# Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling