feat:merge-lora iterate through bins without loading (#3095)
* merge_method added
* merge_efficient core implement
* Update src/axolotl/cli/merge_lora.py
  Co-authored-by: Wing Lian <wing.lian@gmail.com>
* Update src/axolotl/utils/lora_merge_efficient.py
  Co-authored-by: Wing Lian <wing.lian@gmail.com>
* standard to legacy + rstrip + try/except for do_merge_lora_efficient(cfg=cfg)
* fix: 'dict' object has no attribute 'lora_alpha'
* info -> debug
* lint
* lint2
* moved everything to cpu + performance improvements
* lint
* Update src/axolotl/cli/merge_lora.py
  Co-authored-by: Dan Saunders <danjsaund@gmail.com>
* Update src/axolotl/cli/merge_lora.py
  Co-authored-by: Dan Saunders <danjsaund@gmail.com>
* string handling + try/except removed
* merge_method -> merge_lora_methods
* remove duplicate cal + safetensors + move to lora_merge.py
* lint
* handle quant-dequant, handle experts
* fix parameter merging and prefer peft's native merge logic per module

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
@@ -4,9 +4,11 @@ from pathlib import Path
from typing import Union

import fire
import torch

from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

@@ -17,12 +19,26 @@ LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
    """
    Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
    along with the LoRA adapters to combine them into a single base model.
    Merges LoRA adapters with base model using either memory-efficient or legacy approach.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
    """
    merge_method = str(getattr(cfg, "merge_method", "memory_efficient"))
    if merge_method == "legacy":
        LOG.debug("Using legacy LoRA merging method...")
        _do_merge_lora_legacy(cfg=cfg)
    else:
        LOG.debug("Using memory-efficient LoRA merging method...")
        _do_merge_lora_efficient(cfg=cfg)
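A minimal sketch (not part of this patch) of driving the dispatcher above directly from Python; the paths are hypothetical, and a real run needs a full axolotl config, which do_cli below assembles from the YAML config and CLI args:

from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "./base_model",        # hypothetical local path
        "lora_model_dir": "./lora_adapter",  # hypothetical adapter dir
        "output_dir": "./out",
        "merge_method": "legacy",            # or "memory_efficient" (the default)
    }
)
do_merge_lora(cfg=cfg)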

def _do_merge_lora_legacy(*, cfg: DictDefault) -> None:
    """
    Legacy LoRA merging using merge_and_unload.
    Loads the full model into memory before merging.
    """
    LOG.debug("Using legacy LoRA merging method...")
    model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)

    LOG.info("Running merge of LoRA with base model...")

@@ -52,6 +68,58 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
        processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
    """
    Memory-efficient LoRA merging using shard-by-shard processing.
    Does not load the full model into memory.

    Supports standard LoRA, RSLoRA, and DoRA. Unsupported methods (AdaLoRA, VeRA)
    will raise NotImplementedError — use legacy method for those.
    """
    LOG.debug("Using memory-efficient LoRA merging method...")

    output_path = Path(cfg.output_dir) / "merged"
    safe_tensors = getattr(cfg, "save_safetensors", True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Detect NF4 quantization from config to simulate QLoRA training dynamics.
    # Check both current and original (pre-override) config values since do_cli
    # forces load_in_4bit=False for the legacy path.
    simulate_nf4 = bool(
        getattr(cfg, "load_in_4bit", False)
        or getattr(cfg, "_original_load_in_4bit", False)
        or getattr(cfg, "adapter", None) == "qlora"
        or getattr(cfg, "_original_adapter", None) == "qlora"
    )

    bnb_config_kwargs = getattr(cfg, "bnb_config_kwargs", None) or {}
    nf4_blocksize = bnb_config_kwargs.get("blocksize", None)
    nf4_double_quant = bnb_config_kwargs.get(
        "bnb_4bit_use_double_quant",
        getattr(cfg, "bnb_4bit_use_double_quant", True),
    )

    # Detect MoE expert quantization
    simulate_nf4_experts = bool(
        getattr(cfg, "quantize_moe_experts", False)
        or getattr(cfg, "_original_quantize_moe_experts", False)
    )

    merge_lora_sharded_efficient(
        base_model_path=cfg.base_model,
        lora_adapter_path=cfg.lora_model_dir,
        output_path=output_path,
        safe_tensors=safe_tensors,
        device=device,
        simulate_nf4=simulate_nf4,
        simulate_nf4_experts=simulate_nf4_experts,
        nf4_blocksize=nf4_blocksize,
        nf4_double_quant=nf4_double_quant,
    )

    LOG.debug("Memory-efficient LoRA merge completed successfully!")


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various

@@ -66,6 +134,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
        ValueError: If target directory for LoRA merged model does not exist.
    """

    # Pre-load config to detect original quantization settings before overrides
    raw_cfg = load_cfg(config, **kwargs)
    original_load_in_4bit = getattr(raw_cfg, "load_in_4bit", False)
    original_adapter = getattr(raw_cfg, "adapter", None)
    original_quantize_moe_experts = getattr(raw_cfg, "quantize_moe_experts", False)

    parsed_cfg = load_cfg(
        config,
        merge_lora=True,

@@ -80,11 +154,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
        **kwargs,
    )

    # Stash original quantization settings for NF4 simulation in efficient merge
    parsed_cfg._original_load_in_4bit = original_load_in_4bit
    parsed_cfg._original_adapter = original_adapter
    parsed_cfg._original_quantize_moe_experts = original_quantize_moe_experts

    if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
        parsed_cfg.lora_model_dir = parsed_cfg.output_dir
    if not Path(parsed_cfg.lora_model_dir).exists():
        raise ValueError(
            f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
            f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
        )

    do_merge_lora(cfg=parsed_cfg)
982  src/axolotl/cli/utils/lora_merge.py  Normal file
@@ -0,0 +1,982 @@
import gc
import math
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import safetensors
import safetensors.torch
import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig
from tqdm import tqdm

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

def _simulate_nf4_roundtrip(
|
||||
tensor: torch.Tensor,
|
||||
blocksize: Optional[int] = None,
|
||||
compress_statistics: bool = True,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Simulate NF4 quantization roundtrip to match QLoRA training dynamics.
|
||||
|
||||
During QLoRA training, base weights are quantized to NF4 and dequantized on-the-fly
|
||||
for each forward pass. The LoRA adapters learn to compensate for the quantization
|
||||
noise in the dequantized weights. To match this at merge time, we apply the same
|
||||
quantize → dequantize roundtrip so the merged result reflects what the model saw
|
||||
during training.
|
||||
|
||||
Args:
|
||||
tensor: Base model weight tensor (fp16/bf16/fp32)
|
||||
blocksize: NF4 quantization block size (default: bitsandbytes default)
|
||||
compress_statistics: Whether to use double quantization
|
||||
device: Device for quantization computation. bitsandbytes requires a
|
||||
CUDA device; defaults to "cuda" when available.
|
||||
|
||||
Returns:
|
||||
Tensor after NF4 quantize → dequantize roundtrip, in original dtype
|
||||
"""
|
||||
import bitsandbytes.functional as bnb_F
|
||||
|
||||
quant_device: torch.device
|
||||
if device is None:
|
||||
quant_device = torch.device("cuda")
|
||||
elif isinstance(device, str):
|
||||
quant_device = torch.device(device)
|
||||
else:
|
||||
quant_device = device
|
||||
|
||||
if quant_device.type == "cuda" and not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
"NF4 simulation requires CUDA but no GPU is available. "
|
||||
"Either run on a machine with a GPU or disable NF4 simulation."
|
||||
)
|
||||
|
||||
original_dtype = tensor.dtype
|
||||
original_shape = tensor.shape
|
||||
|
||||
# bitsandbytes requires float32 input for quantization and contiguous+CUDA tensor
|
||||
flat = tensor.reshape(-1).to(torch.float32).contiguous().to(quant_device)
|
||||
|
||||
quant_kwargs = {
|
||||
"quant_type": "nf4",
|
||||
"compress_statistics": compress_statistics,
|
||||
}
|
||||
if blocksize is not None:
|
||||
quant_kwargs["blocksize"] = blocksize
|
||||
|
||||
quantized, quant_state = bnb_F.quantize_4bit(flat, **quant_kwargs)
|
||||
dequantized = bnb_F.dequantize_4bit(quantized, quant_state, quant_type="nf4")
|
||||
|
||||
return dequantized.reshape(original_shape).to(original_dtype).cpu()
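The quantization noise this roundtrip introduces can be eyeballed in isolation with bitsandbytes directly; a rough sketch (assumes a CUDA device and mirrors the flatten/float32 handling above, not part of this diff):

import torch
import bitsandbytes.functional as bnb_F

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
flat = w.reshape(-1).to(torch.float32).contiguous()
quantized, state = bnb_F.quantize_4bit(flat, quant_type="nf4")
w_rt = bnb_F.dequantize_4bit(quantized, state, quant_type="nf4").reshape(w.shape).to(w.dtype)
# Mean absolute roundtrip error is small but non-zero; this is the noise the
# LoRA adapters learned to compensate for during QLoRA training.
print((w - w_rt).abs().mean().item())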
|
||||
|
||||
|
||||
def find_lora_weights(
|
||||
lora_state: Dict[str, torch.Tensor],
|
||||
key: str,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Find corresponding LoRA A and B weights for a given key.
|
||||
|
||||
Also tries keys after applying weight renamings (from transformers v5
|
||||
conversion mappings) in case the checkpoint key names differ from the
|
||||
runtime model key names used by the LoRA adapter.
|
||||
"""
|
||||
import re
|
||||
|
||||
clean_key = key[:-7] if key.endswith(".weight") else key
|
||||
|
||||
# Try the direct key first
|
||||
a_key = f"base_model.model.{clean_key}.lora_A.weight"
|
||||
b_key = f"base_model.model.{clean_key}.lora_B.weight"
|
||||
|
||||
lora_a = lora_state.get(a_key)
|
||||
lora_b = lora_state.get(b_key)
|
||||
|
||||
if lora_a is not None and lora_b is not None:
|
||||
return lora_a, lora_b
|
||||
|
||||
# Try renamed keys (checkpoint format → runtime format)
|
||||
if weight_renamings:
|
||||
for src_pattern, tgt_pattern in weight_renamings.items():
|
||||
renamed_key = re.sub(src_pattern, tgt_pattern, clean_key)
|
||||
if renamed_key != clean_key:
|
||||
a_key = f"base_model.model.{renamed_key}.lora_A.weight"
|
||||
b_key = f"base_model.model.{renamed_key}.lora_B.weight"
|
||||
lora_a = lora_state.get(a_key)
|
||||
lora_b = lora_state.get(b_key)
|
||||
if lora_a is not None and lora_b is not None:
|
||||
return lora_a, lora_b
|
||||
|
||||
return None, None
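For reference, the key convention this lookup relies on (the same layout exercised by the tests at the bottom of this diff), shown for an example layer name:

base_key = "model.layers.0.self_attn.q_proj.weight"
stem = base_key[: -len(".weight")]
a_key = f"base_model.model.{stem}.lora_A.weight"  # shape [r, in_features]
b_key = f"base_model.model.{stem}.lora_B.weight"  # shape [out_features, r]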
|
||||
|
||||
|
||||
def _find_param_wrapper_lora(
|
||||
lora_state: Dict[str, torch.Tensor],
|
||||
key: str,
|
||||
tensor_shape: Optional[tuple] = None,
|
||||
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[str]]:
|
||||
"""
|
||||
Find LoRA weights from a ParamWrapper (lora_target_parameters) that targets
|
||||
a parent module containing this weight as a sub-parameter.
|
||||
|
||||
For example, base weight key 'model.layers.0.mlp.experts.down_proj' may have
|
||||
LoRA at 'base_model.model.model.layers.0.mlp.experts.lora_A.weight' (targeting
|
||||
the 'experts' module with 'down_proj' as the parameter_name).
|
||||
|
||||
When tensor_shape is provided, validates that the LoRA dimensions match the
|
||||
target tensor (important when multiple ParamWrappers are nested and each
|
||||
nesting level has different LoRA dimensions).
|
||||
|
||||
Returns (lora_A, lora_B, parameter_name) or (None, None, None).
|
||||
"""
|
||||
clean_key = key[:-7] if key.endswith(".weight") else key
|
||||
# Strip trailing parameter name to get the parent module path
|
||||
# e.g., "model.layers.0.mlp.experts.down_proj" → parent="model.layers.0.mlp.experts", param="down_proj"
|
||||
parts = clean_key.rsplit(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None, None, None
|
||||
|
||||
parent_key, param_name = parts
|
||||
|
||||
# PEFT's ParamWrapper nesting: when multiple parameters are targeted on
|
||||
# the same module, it nests wrappers. The outer wrapper's LoRA is at
|
||||
# parent.lora_A/B and inner wrappers use parent.base_layer.lora_A/B,
|
||||
# parent.base_layer.base_layer.lora_A/B, etc.
|
||||
prefixes_to_try = [
|
||||
f"base_model.model.{parent_key}",
|
||||
]
|
||||
# Walk up .base_layer nesting levels (typically 1-2 deep)
|
||||
for depth in range(1, 4):
|
||||
bl = ".base_layer" * depth
|
||||
prefixes_to_try.append(f"base_model.model.{parent_key}{bl}")
|
||||
|
||||
for prefix in prefixes_to_try:
|
||||
a_key = f"{prefix}.lora_A.weight"
|
||||
b_key = f"{prefix}.lora_B.weight"
|
||||
lora_a = lora_state.get(a_key)
|
||||
lora_b = lora_state.get(b_key)
|
||||
if lora_a is None or lora_b is None:
|
||||
continue
|
||||
|
||||
# When tensor_shape is given, verify dimensions match before returning.
|
||||
# This prevents returning a mismatched LoRA from a different nesting level.
|
||||
if tensor_shape is not None and len(tensor_shape) >= 3:
|
||||
num_experts = tensor_shape[0]
|
||||
if not (
|
||||
lora_a.shape[0] == lora_b.shape[1]
|
||||
and lora_a.shape[0] % num_experts == 0
|
||||
and lora_a.shape[1] == tensor_shape[1]
|
||||
and lora_b.shape[0] == tensor_shape[2]
|
||||
):
|
||||
continue # Dimensions don't match, try next nesting level
|
||||
|
||||
return lora_a, lora_b, param_name
|
||||
|
||||
return None, None, None
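To make the nesting concrete: for a stacked expert weight of shape [num_experts, in_features, out_features], the outer wrapper's adapter lives at `...experts.lora_A/lora_B.weight` and the next targeted parameter's wrapper sits one `.base_layer` deeper, and the dimension check above accepts a pair only when the shapes line up as below (illustrative sizes, not part of this diff):

import torch

num_experts, in_features, out_features, r = 4, 8, 16, 2
lora_a = torch.randn(r * num_experts, in_features)   # e.g. ...experts.lora_A.weight
lora_b = torch.randn(out_features, r * num_experts)  # e.g. ...experts.lora_B.weight
assert lora_a.shape[0] == lora_b.shape[1] and lora_a.shape[0] % num_experts == 0
assert lora_a.shape[1] == in_features and lora_b.shape[0] == out_features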
|
||||
|
||||
|
||||
def _build_peft_layer_and_get_delta(
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
lora_config_dict: Dict,
|
||||
base_tensor: torch.Tensor,
|
||||
adapter_name: str = "default",
|
||||
is_param_wrapper: bool = False,
|
||||
magnitude: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Use PEFT's own layer classes to compute the LoRA delta weight.
|
||||
|
||||
Instead of re-implementing the merge math for every LoRA variant, this
|
||||
constructs a lightweight PEFT layer, loads the A/B weights, and calls
|
||||
``get_delta_weight`` (or ``merge`` for DoRA) which handles standard LoRA,
|
||||
RSLoRA, DoRA, and ParamWrapper (expert-blocked) LoRA.
|
||||
|
||||
Returns the delta tensor (same shape as base_tensor).
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
r_total = lora_a.shape[0]
|
||||
in_features = lora_a.shape[1]
|
||||
out_features = lora_b.shape[0]
|
||||
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
||||
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
|
||||
|
||||
if is_param_wrapper:
|
||||
from peft.tuners.lora.layer import ParamWrapper
|
||||
|
||||
num_experts = base_tensor.shape[0]
|
||||
r = r_total // num_experts
|
||||
|
||||
class _FakeModule(nn.Module):
|
||||
pass
|
||||
|
||||
fake = _FakeModule()
|
||||
fake.register_parameter(
|
||||
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
layer = ParamWrapper(
|
||||
fake,
|
||||
adapter_name=adapter_name,
|
||||
parameter_name="weight",
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
use_rslora=use_rslora,
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
return layer.get_delta_weight(adapter_name)
|
||||
else:
|
||||
from peft.tuners.lora.layer import Linear as LoraLinear
|
||||
|
||||
base_layer = nn.Linear(in_features, out_features, bias=False)
|
||||
base_layer.weight.data = base_tensor.clone()
|
||||
|
||||
fan_in_fan_out = bool(
|
||||
lora_config_dict.get("fan_in_fan_out", False)
|
||||
or lora_config_dict.get("lora_fan_in_fan_out", False)
|
||||
)
|
||||
|
||||
layer = LoraLinear(
|
||||
base_layer,
|
||||
adapter_name=adapter_name,
|
||||
r=r_total,
|
||||
lora_alpha=lora_alpha,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
|
||||
if use_dora:
|
||||
# DoRA merges magnitude normalization into the weight directly.
|
||||
# Use PEFT's merge() which handles DoRA internally, then
|
||||
# compute the delta as merged_weight - original_weight.
|
||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||
mag_layer.weight = nn.Parameter(magnitude)
|
||||
layer.merge(adapter_names=[adapter_name])
|
||||
return base_layer.weight.data - base_tensor
|
||||
|
||||
return layer.get_delta_weight(adapter_name)
|
||||
|
||||
|
||||
def get_model_shards(model_path: Path) -> list[Path]:
|
||||
"""Find all model shards in the given path."""
|
||||
shards: list[Path] = []
|
||||
|
||||
patterns = ["model*.safetensors", "pytorch_model*.bin"]
|
||||
|
||||
for pattern in patterns:
|
||||
shards.extend(model_path.glob(pattern))
|
||||
if shards:
|
||||
break
|
||||
|
||||
return sorted(shards)
|
||||
|
||||
|
||||
def copy_non_model_files(
|
||||
input_path: Path, output_path: Path, model_shards: list[Path]
|
||||
) -> None:
|
||||
"""
|
||||
Copy all non-model files to the output directory.
|
||||
|
||||
Args:
|
||||
input_path: Source directory
|
||||
output_path: Destination directory
|
||||
model_shards: List of model shard files to skip
|
||||
"""
|
||||
LOG.info("Copying non-model files to output directory...")
|
||||
|
||||
shard_names = {shard.name for shard in model_shards}
|
||||
|
||||
for filepath in input_path.glob("*"):
|
||||
if filepath.is_dir():
|
||||
continue
|
||||
if filepath.name in shard_names:
|
||||
continue
|
||||
if (
|
||||
filepath.name.startswith("model") and filepath.suffix == ".safetensors"
|
||||
) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
|
||||
continue
|
||||
if filepath.suffix == ".gguf":
|
||||
continue
|
||||
# Skip weight-map index files — they reference shard filenames that may
|
||||
# change during the merge (e.g. .bin → .safetensors). A correct index
|
||||
# is regenerated after all shards have been written.
|
||||
if filepath.name.endswith(".index.json"):
|
||||
continue
|
||||
|
||||
LOG.debug(f"Copying {filepath.name} to output")
|
||||
shutil.copy2(filepath, output_path)
|
||||
|
||||
|
||||
def _find_dora_magnitude(
|
||||
lora_state: Dict[str, torch.Tensor],
|
||||
key: str,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Find DoRA magnitude vector for a given key.
|
||||
"""
|
||||
import re
|
||||
|
||||
clean_key = key[:-7] if key.endswith(".weight") else key
|
||||
mag_key = f"base_model.model.{clean_key}.lora_magnitude_vector"
|
||||
result = lora_state.get(mag_key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
if weight_renamings:
|
||||
for src_pattern, tgt_pattern in weight_renamings.items():
|
||||
renamed_key = re.sub(src_pattern, tgt_pattern, clean_key)
|
||||
if renamed_key != clean_key:
|
||||
mag_key = f"base_model.model.{renamed_key}.lora_magnitude_vector"
|
||||
result = lora_state.get(mag_key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _should_nf4_roundtrip(
|
||||
key: str,
|
||||
tensor: torch.Tensor,
|
||||
simulate_nf4: bool,
|
||||
simulate_nf4_experts: bool,
|
||||
) -> bool:
|
||||
"""Determine if a tensor should undergo NF4 quantization roundtrip."""
|
||||
if tensor.ndim < 2:
|
||||
return False
|
||||
if simulate_nf4:
|
||||
return True
|
||||
if simulate_nf4_experts and tensor.ndim >= 3 and "expert" in key.lower():
|
||||
return True
|
||||
return False
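An illustrative trace of the rule above (tensor sizes chosen to mirror the unit tests, not part of this diff):

import torch

_should_nf4_roundtrip("model.layers.0.self_attn.q_proj.weight", torch.randn(32, 32), True, False)     # True: 2-D weight under QLoRA
_should_nf4_roundtrip("model.layers.0.input_layernorm.weight", torch.randn(32), True, False)          # False: 1-D tensors never roundtrip
_should_nf4_roundtrip("model.layers.0.mlp.experts.gate_up_proj", torch.randn(4, 32, 64), False, True) # True: 3-D expert stack with quantize_moe_experts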
|
||||
|
||||
|
||||
def _merge_tensor_with_lora(
|
||||
tensor: torch.Tensor,
|
||||
key: str,
|
||||
lora_state: Dict[str, torch.Tensor],
|
||||
scale: float,
|
||||
lora_config_dict: Dict,
|
||||
device: str,
|
||||
simulate_nf4: bool = False,
|
||||
simulate_nf4_experts: bool = False,
|
||||
nf4_blocksize: Optional[int] = None,
|
||||
nf4_double_quant: bool = True,
|
||||
use_dora: bool = False,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[torch.Tensor, bool]:
|
||||
"""
|
||||
Helper function to merge a single tensor with its corresponding LoRA weights.
|
||||
|
||||
Args:
|
||||
tensor: Base model tensor
|
||||
key: Tensor key/name
|
||||
lora_state: Dictionary containing LoRA weights
|
||||
scale: LoRA scaling factor (alpha/r)
|
||||
lora_config_dict: LoRA configuration dictionary
|
||||
device: Device to perform computations on
|
||||
simulate_nf4: Whether to simulate NF4 quantization roundtrip for all weights
|
||||
simulate_nf4_experts: Whether to simulate NF4 roundtrip for MoE expert tensors only
|
||||
nf4_blocksize: Block size for NF4 quantization
|
||||
nf4_double_quant: Whether to use double quantization
|
||||
use_dora: Whether to apply DoRA (Weight-Decomposed LoRA) merging
|
||||
weight_renamings: Optional key renamings from transformers conversion mapping
|
||||
|
||||
Returns:
|
||||
Tuple of (merged tensor, whether LoRA was applied)
|
||||
"""
|
||||
lora_a, lora_b = find_lora_weights(lora_state, key, weight_renamings)
|
||||
|
||||
do_nf4 = _should_nf4_roundtrip(key, tensor, simulate_nf4, simulate_nf4_experts)
|
||||
|
||||
if lora_a is not None and lora_b is not None:
|
||||
LOG.debug(f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}")
|
||||
|
||||
original_dtype = tensor.dtype
|
||||
|
||||
# Simulate NF4 quantization roundtrip to match QLoRA training dynamics
|
||||
if do_nf4:
|
||||
tensor = _simulate_nf4_roundtrip(
|
||||
tensor,
|
||||
blocksize=nf4_blocksize,
|
||||
compress_statistics=nf4_double_quant,
|
||||
device=device,
|
||||
)
|
||||
|
||||
magnitude = (
|
||||
_find_dora_magnitude(lora_state, key, weight_renamings)
|
||||
if use_dora
|
||||
else None
|
||||
)
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
lora_a.to(device),
|
||||
lora_b.to(device),
|
||||
lora_config_dict,
|
||||
tensor.to(device),
|
||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||
)
|
||||
merged_tensor = (
|
||||
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
||||
.to(original_dtype)
|
||||
.detach()
|
||||
.cpu()
|
||||
)
|
||||
return merged_tensor, True
|
||||
else:
|
||||
# Try ParamWrapper LoRA (lora_target_parameters) — the LoRA targets a
|
||||
# parent module and this weight is a sub-parameter of that module.
|
||||
if tensor.ndim >= 3:
|
||||
pw_a, pw_b, param_name = _find_param_wrapper_lora(
|
||||
lora_state, key, tensor_shape=tuple(tensor.shape)
|
||||
)
|
||||
if pw_a is not None and pw_b is not None:
|
||||
LOG.debug(
|
||||
f"Merging ParamWrapper LoRA for {key} "
|
||||
f"(param={param_name}): {pw_a.shape}, {pw_b.shape}"
|
||||
)
|
||||
if do_nf4:
|
||||
tensor = _simulate_nf4_roundtrip(
|
||||
tensor,
|
||||
blocksize=nf4_blocksize,
|
||||
compress_statistics=nf4_double_quant,
|
||||
device=device,
|
||||
)
|
||||
original_dtype = tensor.dtype
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
pw_a.to(device),
|
||||
pw_b.to(device),
|
||||
lora_config_dict,
|
||||
tensor.to(device),
|
||||
is_param_wrapper=True,
|
||||
)
|
||||
merged = (
|
||||
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
||||
.to(original_dtype)
|
||||
.detach()
|
||||
.cpu()
|
||||
)
|
||||
return merged, True
|
||||
|
||||
if do_nf4:
|
||||
tensor = _simulate_nf4_roundtrip(
|
||||
tensor,
|
||||
blocksize=nf4_blocksize,
|
||||
compress_statistics=nf4_double_quant,
|
||||
device=device,
|
||||
)
|
||||
return tensor.detach().cpu(), False
|
||||
|
||||
|
||||
def _get_conversion_info(base_model_path: Path) -> tuple[Dict[str, str], list]:
|
||||
"""
|
||||
Load the model's config.json and check if transformers has WeightRenaming
|
||||
or WeightConverter mappings for this model type.
|
||||
|
||||
Returns:
|
||||
- dict of {source_pattern: target_pattern} for simple renamings
|
||||
- list of WeightConverter objects for fuse/unfuse operations
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
config_path = base_model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
return {}, []
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
model_config = _json.load(f)
|
||||
except (OSError, _json.JSONDecodeError):
|
||||
return {}, []
|
||||
|
||||
model_type = model_config.get("model_type")
|
||||
if not model_type:
|
||||
return {}, []
|
||||
|
||||
try:
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
from transformers.core_model_loading import WeightConverter, WeightRenaming
|
||||
except ImportError:
|
||||
return {}, []
|
||||
|
||||
conversions = get_checkpoint_conversion_mapping(model_type)
|
||||
if not conversions:
|
||||
return {}, []
|
||||
|
||||
renamings = {}
|
||||
weight_converters = []
|
||||
for conv in conversions:
|
||||
if isinstance(conv, WeightRenaming):
|
||||
# WeightRenaming stores patterns as lists internally
|
||||
src_list = (
|
||||
conv.source_patterns
|
||||
if isinstance(conv.source_patterns, list)
|
||||
else [conv.source_patterns]
|
||||
)
|
||||
tgt_list = (
|
||||
conv.target_patterns
|
||||
if isinstance(conv.target_patterns, list)
|
||||
else [conv.target_patterns]
|
||||
)
|
||||
if len(src_list) == 1 and len(tgt_list) == 1:
|
||||
renamings[src_list[0]] = tgt_list[0]
|
||||
elif isinstance(conv, WeightConverter):
|
||||
weight_converters.append(conv)
|
||||
|
||||
return renamings, weight_converters
|
||||
|
||||
|
||||
def _fuse_and_unfuse_with_merge(
|
||||
shard_tensors: Dict[str, torch.Tensor],
|
||||
weight_converters: list,
|
||||
lora_state: Dict[str, torch.Tensor],
|
||||
scale: float,
|
||||
lora_config_dict: Dict,
|
||||
device: str,
|
||||
simulate_nf4: bool = False,
|
||||
simulate_nf4_experts: bool = False,
|
||||
nf4_blocksize: Optional[int] = None,
|
||||
nf4_double_quant: bool = True,
|
||||
use_dora: bool = False,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
||||
"""
|
||||
For tensors matching WeightConverter patterns (MoE expert weights):
|
||||
1. Fuse checkpoint-format tensors into runtime-format (e.g., per-expert → fused 3D)
|
||||
2. Apply NF4 roundtrip + LoRA merge on the fused tensor
|
||||
3. Unfuse back to checkpoint format for saving
|
||||
|
||||
Returns:
|
||||
- Updated tensor dict
|
||||
- Count of merged LoRA targets
|
||||
- Set of keys that were processed (fused/merged/unfused) and should be
|
||||
skipped by the per-tensor merge pass to avoid double NF4 roundtrip
|
||||
"""
|
||||
import re
|
||||
|
||||
from transformers.core_model_loading import Concatenate, MergeModulelist
|
||||
|
||||
result = dict(shard_tensors) # Start with all tensors
|
||||
merged_count = 0
|
||||
processed_keys: set = set() # Keys that were fuse/unfuse processed
|
||||
|
||||
for converter in weight_converters:
|
||||
src_patterns = (
|
||||
converter.source_patterns
|
||||
if isinstance(converter.source_patterns, list)
|
||||
else [converter.source_patterns]
|
||||
)
|
||||
tgt_patterns = (
|
||||
converter.target_patterns
|
||||
if isinstance(converter.target_patterns, list)
|
||||
else [converter.target_patterns]
|
||||
)
|
||||
|
||||
# Build regex for each source pattern
|
||||
pattern_regexes = []
|
||||
for pat in src_patterns:
|
||||
regex_str = re.escape(pat).replace(r"\.\*\.", r"\.(\d+)\.")
|
||||
regex_str = (
|
||||
regex_str.rstrip(r"\$") if regex_str.endswith(r"\$") else regex_str
|
||||
)
|
||||
pattern_regexes.append(re.compile(r"(.*\.)?" + regex_str + "$"))
|
||||
|
||||
# Group matching keys by layer prefix and source pattern
|
||||
# {layer_prefix: {pat_idx: {expert_idx: (key, tensor)}}}
|
||||
layer_groups: Dict[str, Dict[int, Dict[int, tuple[str, torch.Tensor]]]] = {}
|
||||
|
||||
for key in list(result.keys()):
|
||||
for pat_idx, pat_regex in enumerate(pattern_regexes):
|
||||
match = pat_regex.match(key)
|
||||
if match:
|
||||
prefix = match.group(1) or ""
|
||||
# Extract expert index from the matched portion
|
||||
remaining = key[len(prefix) :]
|
||||
expert_match = re.search(r"\.(\d+)\.", remaining)
|
||||
expert_idx = int(expert_match.group(1)) if expert_match else 0
|
||||
|
||||
layer_groups.setdefault(prefix, {}).setdefault(pat_idx, {})[
|
||||
expert_idx
|
||||
] = (key, result[key])
|
||||
break
|
||||
|
||||
# Process each layer group
|
||||
for prefix, pat_groups in layer_groups.items():
|
||||
# Check we have all source patterns for this layer
|
||||
if not pat_groups:
|
||||
continue
|
||||
|
||||
# Step 1: Fuse — MergeModulelist (stack experts) per source pattern
|
||||
fused_per_pattern = {}
|
||||
original_keys_per_pattern: Dict[int, list[str]] = {}
|
||||
num_experts = None
|
||||
|
||||
for pat_idx in sorted(pat_groups.keys()):
|
||||
expert_data = pat_groups[pat_idx]
|
||||
sorted_indices = sorted(expert_data.keys())
|
||||
if num_experts is None:
|
||||
num_experts = len(sorted_indices)
|
||||
|
||||
sorted_tensors = [expert_data[idx][1] for idx in sorted_indices]
|
||||
original_keys_per_pattern[pat_idx] = [
|
||||
expert_data[idx][0] for idx in sorted_indices
|
||||
]
|
||||
fused_per_pattern[src_patterns[pat_idx]] = torch.stack(
|
||||
sorted_tensors, dim=0
|
||||
)
|
||||
|
||||
# Apply remaining operations (Concatenate)
|
||||
fused_tensor = None
|
||||
has_concat = False
|
||||
concat_dim = 1 # default
|
||||
|
||||
for op in converter.operations:
|
||||
if isinstance(op, MergeModulelist):
|
||||
pass # Already handled
|
||||
elif isinstance(op, Concatenate):
|
||||
has_concat = True
|
||||
concat_dim = op.dim
|
||||
tensors_to_cat = [
|
||||
fused_per_pattern[sp]
|
||||
for sp in src_patterns
|
||||
if sp in fused_per_pattern
|
||||
]
|
||||
if len(tensors_to_cat) > 1:
|
||||
fused_tensor = torch.cat(tensors_to_cat, dim=concat_dim)
|
||||
elif tensors_to_cat:
|
||||
fused_tensor = tensors_to_cat[0]
|
||||
|
||||
if not has_concat and len(fused_per_pattern) == 1:
|
||||
fused_tensor = next(iter(fused_per_pattern.values()))
|
||||
|
||||
if fused_tensor is None:
|
||||
continue
|
||||
|
||||
# Step 2: Build the fused key name and merge LoRA
|
||||
fused_key = prefix + tgt_patterns[0]
|
||||
|
||||
# Apply NF4 roundtrip on the fused tensor (matching training dynamics)
|
||||
do_nf4 = _should_nf4_roundtrip(
|
||||
fused_key, fused_tensor, simulate_nf4, simulate_nf4_experts
|
||||
)
|
||||
if do_nf4:
|
||||
fused_tensor = _simulate_nf4_roundtrip(
|
||||
fused_tensor,
|
||||
blocksize=nf4_blocksize,
|
||||
compress_statistics=nf4_double_quant,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Try to find and merge LoRA weights for the fused key
|
||||
lora_a, lora_b = find_lora_weights(lora_state, fused_key, weight_renamings)
|
||||
if lora_a is not None and lora_b is not None:
|
||||
LOG.debug(
|
||||
f"Merging LoRA for fused key {fused_key}: {lora_a.shape}, {lora_b.shape}"
|
||||
)
|
||||
original_dtype = fused_tensor.dtype
|
||||
magnitude = (
|
||||
_find_dora_magnitude(lora_state, fused_key, weight_renamings)
|
||||
if use_dora
|
||||
else None
|
||||
)
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
lora_a.to(device),
|
||||
lora_b.to(device),
|
||||
lora_config_dict,
|
||||
fused_tensor.to(device),
|
||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||
)
|
||||
fused_tensor = (
|
||||
(
|
||||
fused_tensor.to(device).to(torch.float32)
|
||||
+ delta.to(torch.float32)
|
||||
)
|
||||
.to(original_dtype)
|
||||
.detach()
|
||||
.cpu()
|
||||
)
|
||||
merged_count += 1
|
||||
|
||||
# Step 3: Save in fused format (runtime format) so that the merged
|
||||
# model can be loaded directly without needing WeightConverter
|
||||
# fusion during from_pretrained (which can OOM for large MoE models).
|
||||
# Remove the original per-expert keys and save the fused tensor
|
||||
# under the runtime key name.
|
||||
for pat_idx in sorted(original_keys_per_pattern.keys()):
|
||||
for ok in original_keys_per_pattern[pat_idx]:
|
||||
result.pop(ok, None)
|
||||
processed_keys.add(ok)
|
||||
|
||||
result[fused_key] = fused_tensor.detach().cpu()
|
||||
processed_keys.add(fused_key)
|
||||
|
||||
return result, merged_count, processed_keys
|
||||
|
||||
|
||||
def merge_lora_sharded_efficient(
|
||||
base_model_path: Union[str, Path],
|
||||
lora_adapter_path: Union[str, Path],
|
||||
output_path: Union[str, Path],
|
||||
device: str = "cpu",
|
||||
safe_tensors: bool = True,
|
||||
simulate_nf4: bool = False,
|
||||
simulate_nf4_experts: bool = False,
|
||||
nf4_blocksize: Optional[int] = None,
|
||||
nf4_double_quant: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Memory-efficient LoRA merging that processes shards individually
|
||||
without loading the full model into memory.
|
||||
|
||||
Args:
|
||||
simulate_nf4: Apply NF4 roundtrip to ALL weight tensors (for QLoRA)
|
||||
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
||||
(for quantize_moe_experts). Expert tensors are identified by having
|
||||
"expert" in the key name and ndim >= 3.
|
||||
"""
|
||||
base_model_path = Path(base_model_path)
|
||||
lora_adapter_path = Path(lora_adapter_path)
|
||||
output_path = Path(output_path)
|
||||
|
||||
if "/" in str(base_model_path) and not base_model_path.exists():
|
||||
base_model_path = Path(snapshot_download(str(base_model_path)))
|
||||
|
||||
# Check for weight conversion requirements (transformers v5)
|
||||
weight_renamings, weight_converters = _get_conversion_info(base_model_path)
|
||||
if weight_renamings:
|
||||
LOG.debug(f"Found {len(weight_renamings)} weight renamings for this model type")
|
||||
if weight_converters:
|
||||
LOG.debug(
|
||||
f"Found {len(weight_converters)} weight converters (fuse/unfuse) for this model type. "
|
||||
f"Will fuse→merge→unfuse within each shard."
|
||||
)
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
config_file = lora_adapter_path / "adapter_config.json"
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"LoRA config not found: {config_file}")
|
||||
|
||||
lora_config_dict = LoraConfig.from_json_file(str(config_file))
|
||||
if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0:
|
||||
raise ValueError("LoRA config 'r' must be > 0")
|
||||
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||
|
||||
unsupported_methods = []
|
||||
|
||||
# Check for AdaLoRA (Adaptive LoRA)
|
||||
if lora_config_dict.get("use_adalora", False):
|
||||
unsupported_methods.append("AdaLoRA (Adaptive LoRA)")
|
||||
|
||||
# Check for VeRA (Vector-based Random Matrix Adaptation)
|
||||
if lora_config_dict.get("use_vera", False):
|
||||
unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)")
|
||||
|
||||
# Check for other advanced LoRA variants by task_type
|
||||
task_type = lora_config_dict.get("task_type", "")
|
||||
if task_type and task_type not in [
|
||||
"CAUSAL_LM",
|
||||
"SEQ_2_SEQ_LM",
|
||||
"TOKEN_CLS",
|
||||
"SEQ_CLS",
|
||||
"QUESTION_ANS",
|
||||
]:
|
||||
unsupported_methods.append(f"Task type: {task_type}")
|
||||
|
||||
# Check for rank adaptation patterns (AdaLoRA indicators)
|
||||
# Use .get() so empty dicts/None don't false-positive
|
||||
if any(
|
||||
lora_config_dict.get(key)
|
||||
for key in ["rank_pattern", "alpha_pattern", "target_rank"]
|
||||
):
|
||||
unsupported_methods.append("AdaLoRA (rank adaptation detected)")
|
||||
|
||||
# Check for advanced initialization methods
|
||||
init_lora_weights = lora_config_dict.get("init_lora_weights", "")
|
||||
if init_lora_weights and init_lora_weights not in [
|
||||
"gaussian",
|
||||
"loftq",
|
||||
True,
|
||||
False,
|
||||
]:
|
||||
unsupported_methods.append(f"Advanced initialization: {init_lora_weights}")
|
||||
|
||||
if unsupported_methods:
|
||||
methods_str = ", ".join(unsupported_methods)
|
||||
raise NotImplementedError(
|
||||
f"Memory-efficient LoRA merge only supports standard LoRA. "
|
||||
f"Detected unsupported methods: {methods_str}. "
|
||||
f"Please use the legacy merge method for advanced LoRA variants."
|
||||
)
|
||||
|
||||
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
||||
if use_rslora:
|
||||
scale = float(lora_config_dict["lora_alpha"]) / math.sqrt(
|
||||
float(lora_config_dict["r"])
|
||||
)
|
||||
else:
|
||||
scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"])
|
||||
|
||||
LOG.debug(f"LoRA scale factor: {scale} (rslora={use_rslora})")
|
||||
|
||||
if simulate_nf4:
|
||||
LOG.info(
|
||||
"NF4 simulation enabled: base weights will undergo quantize→dequantize "
|
||||
"roundtrip before LoRA merge to match QLoRA training dynamics"
|
||||
)
|
||||
|
||||
lora_file = lora_adapter_path / "adapter_model.safetensors"
|
||||
if not lora_file.exists():
|
||||
lora_file = lora_adapter_path / "adapter_model.bin"
|
||||
if not lora_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"LoRA adapter weights not found in {lora_adapter_path}"
|
||||
)
|
||||
|
||||
LOG.debug(f"Loading LoRA weights from {lora_file}")
|
||||
|
||||
if lora_file.suffix == ".safetensors":
|
||||
lora_state = safetensors.torch.load_file(lora_file)
|
||||
else:
|
||||
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614
|
||||
LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")
|
||||
|
||||
model_shards = get_model_shards(base_model_path)
|
||||
if not model_shards:
|
||||
raise FileNotFoundError(f"No model shards found in {base_model_path}")
|
||||
|
||||
LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
|
||||
copy_non_model_files(base_model_path, output_path, model_shards)
|
||||
|
||||
merged_count = 0
|
||||
total_tensors = 0
|
||||
# Track weight_map for index regeneration: {tensor_key: shard_filename}
|
||||
weight_map: Dict[str, str] = {}
|
||||
|
||||
for shard_path in tqdm(model_shards, desc="Merging shards"):
|
||||
merged_tensors = {}
|
||||
metadata = {}
|
||||
|
||||
# Load all tensors from the shard
|
||||
if shard_path.suffix == ".safetensors":
|
||||
with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
|
||||
if hasattr(f, "metadata") and f.metadata():
|
||||
metadata = f.metadata()
|
||||
shard_tensors = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
shard_tensors = torch.load( # nosec B614: loading trusted model weights
|
||||
shard_path, map_location="cpu", weights_only=True
|
||||
)
|
||||
|
||||
total_tensors += len(shard_tensors)
|
||||
|
||||
# Step 1: Handle fused weight conversions (MoE experts) if applicable
|
||||
fused_keys: set = set()
|
||||
if weight_converters:
|
||||
shard_tensors, fused_merged, fused_keys = _fuse_and_unfuse_with_merge(
|
||||
shard_tensors,
|
||||
weight_converters,
|
||||
lora_state,
|
||||
scale,
|
||||
lora_config_dict,
|
||||
device,
|
||||
simulate_nf4=simulate_nf4,
|
||||
simulate_nf4_experts=simulate_nf4_experts,
|
||||
nf4_blocksize=nf4_blocksize,
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
use_dora=use_dora,
|
||||
weight_renamings=weight_renamings,
|
||||
)
|
||||
merged_count += fused_merged
|
||||
|
||||
# Step 2: Merge remaining (non-fused) tensors with LoRA
|
||||
# Skip keys already processed by fuse/unfuse to avoid double NF4 roundtrip
|
||||
for key, tensor in shard_tensors.items():
|
||||
if key in fused_keys:
|
||||
merged_tensors[key] = tensor.detach().cpu()
|
||||
continue
|
||||
merged_tensor, was_merged = _merge_tensor_with_lora(
|
||||
tensor,
|
||||
key,
|
||||
lora_state,
|
||||
scale,
|
||||
lora_config_dict,
|
||||
device,
|
||||
simulate_nf4=simulate_nf4,
|
||||
simulate_nf4_experts=simulate_nf4_experts,
|
||||
nf4_blocksize=nf4_blocksize,
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
use_dora=use_dora,
|
||||
weight_renamings=weight_renamings,
|
||||
)
|
||||
merged_tensors[key] = merged_tensor
|
||||
if was_merged:
|
||||
merged_count += 1
|
||||
|
||||
output_shard_path = output_path / shard_path.name
|
||||
merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
|
||||
|
||||
if safe_tensors:
|
||||
if not str(output_shard_path).endswith(".safetensors"):
|
||||
output_shard_path = output_path / (shard_path.stem + ".safetensors")
|
||||
safetensors.torch.save_file(
|
||||
merged_tensors, output_shard_path, metadata=metadata
|
||||
)
|
||||
else:
|
||||
if shard_path.suffix == ".safetensors":
|
||||
safetensors.torch.save_file(
|
||||
merged_tensors, output_shard_path, metadata=metadata
|
||||
)
|
||||
else:
|
||||
torch.save(merged_tensors, output_shard_path)
|
||||
|
||||
for tensor_key in merged_tensors:
|
||||
weight_map[tensor_key] = output_shard_path.name
|
||||
|
||||
del merged_tensors, shard_tensors
|
||||
if device != "cpu" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Regenerate weight-map index if the model was sharded
|
||||
if len(model_shards) > 1 and weight_map:
|
||||
import json as _json
|
||||
|
||||
index_name = (
|
||||
"model.safetensors.index.json"
|
||||
if safe_tensors
|
||||
else "pytorch_model.bin.index.json"
|
||||
)
|
||||
index = {
|
||||
"metadata": {"total_size": total_tensors},
|
||||
"weight_map": weight_map,
|
||||
}
|
||||
with open(output_path / index_name, "w") as f:
|
||||
_json.dump(index, f, indent=2)
|
||||
LOG.debug(f"Wrote weight-map index: {index_name}")
|
||||
|
||||
if merged_count == 0:
|
||||
LOG.warning(
|
||||
"No LoRA weights were matched to base model tensors. "
|
||||
"This may indicate a key name mismatch between the checkpoint format "
|
||||
"and the LoRA adapter. Consider using merge_method: legacy."
|
||||
)
|
||||
LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")
|
||||
@@ -164,6 +164,16 @@ def patch_peft_target_parameters_matching():
|
||||
from peft.utils.integrations import init_empty_weights
|
||||
from peft.utils.other import _get_submodules
|
||||
|
||||
# Mapping from unfused parameter names to their fused equivalents.
|
||||
# When a model stores fused weights (e.g. gate_up_proj) but the user
|
||||
# specifies unfused names (gate_proj, up_proj), we auto-expand so the
|
||||
# fused parameter is also targeted. The original unfused names are kept
|
||||
# in the set so that models that do NOT fuse still work.
|
||||
_UNFUSED_TO_FUSED: dict[str, str] = {
|
||||
"gate_proj": "gate_up_proj",
|
||||
"up_proj": "gate_up_proj",
|
||||
}
|
||||
|
||||
def _patched_inject_parameters(
|
||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||
):
|
||||
@@ -176,10 +186,43 @@ def patch_peft_target_parameters_matching():
|
||||
continue
|
||||
for target in original_targets:
|
||||
mod_path, _, param_name = target.rpartition(".")
|
||||
if (
|
||||
if not (
|
||||
module_name == mod_path or module_name.endswith("." + mod_path)
|
||||
) and hasattr(module, param_name):
|
||||
):
|
||||
continue
|
||||
|
||||
if hasattr(module, param_name):
|
||||
expanded.add(f"{module_name}.{param_name}")
|
||||
elif param_name in _UNFUSED_TO_FUSED:
|
||||
# The model uses fused weights (e.g. gate_up_proj) but the
|
||||
# user specified unfused names (gate_proj / up_proj).
|
||||
fused_name = _UNFUSED_TO_FUSED[param_name]
|
||||
if hasattr(module, fused_name):
|
||||
if fused_name not in expanded:
|
||||
LOG.warning(
|
||||
"target_parameter '%s' not found on %s, "
|
||||
"but fused equivalent '%s' exists — adding "
|
||||
"it automatically.",
|
||||
param_name,
|
||||
module_name,
|
||||
fused_name,
|
||||
)
|
||||
expanded.add(f"{module_name}.{fused_name}")
|
||||
else:
|
||||
LOG.warning(
|
||||
"target_parameter '%s' not found on %s and no "
|
||||
"fused equivalent exists either — skipping.",
|
||||
param_name,
|
||||
module_name,
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
"target_parameter '%s' not found on %s — skipping. "
|
||||
"Check that the parameter name matches the model's "
|
||||
"weight names.",
|
||||
param_name,
|
||||
module_name,
|
||||
)
|
||||
|
||||
target_names_set = expanded
|
||||
|
||||
|
||||
@@ -155,6 +155,12 @@ class LoraConfig(BaseModel):
    )

    merge_lora: bool | None = None
    merge_method: Literal["legacy", "memory_efficient"] | None = Field(
        default="memory_efficient",
        json_schema_extra={
            "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory."
        },
    )

    @model_validator(mode="before")
    @classmethod
@@ -1,8 +1,18 @@
import json
import math
from unittest.mock import Mock, patch

import safetensors.torch
import torch

from axolotl.cli.merge_lora import do_merge_lora
from axolotl.cli.utils.lora_merge import (
    _build_peft_layer_and_get_delta,
    _find_param_wrapper_lora,
    _merge_tensor_with_lora,
    find_lora_weights,
    merge_lora_sharded_efficient,
)
from axolotl.utils.dict import DictDefault

@@ -132,6 +142,7 @@ class TestAdapterMergeUnmerge:
                "torch_dtype": torch.float32,
                "local_rank": 0,
                "output_dir": str(tmp_path),
                "merge_method": "legacy",
            }
        )

@@ -167,6 +178,7 @@ class TestAdapterMergeUnmerge:
                "save_safetensors": True,
                "output_dir": str(tmp_path),
                "local_rank": 0,
                "merge_method": "legacy",
            }
        )

@@ -179,3 +191,611 @@ class TestAdapterMergeUnmerge:
|
||||
do_merge_lora(cfg=cfg)
|
||||
|
||||
assert mock_load.called
|
||||
|
||||
|
||||
class TestEfficientMerge:
|
||||
"""Test suite for memory-efficient shard-by-shard LoRA merge."""
|
||||
|
||||
def _make_adapter(self, tmp_path, r=8, alpha=16, use_dora=False, use_rslora=False):
|
||||
"""Create a minimal adapter directory with config + weights."""
|
||||
adapter_dir = tmp_path / "adapter"
|
||||
adapter_dir.mkdir()
|
||||
|
||||
config = {
|
||||
"r": r,
|
||||
"lora_alpha": alpha,
|
||||
"target_modules": ["q_proj", "v_proj"],
|
||||
"task_type": "CAUSAL_LM",
|
||||
"bias": "none",
|
||||
"use_dora": use_dora,
|
||||
"use_rslora": use_rslora,
|
||||
}
|
||||
(adapter_dir / "adapter_config.json").write_text(json.dumps(config))
|
||||
return adapter_dir, config
|
||||
|
||||
def _make_base_model(self, tmp_path, hidden=32):
|
||||
"""Create a minimal base model directory with one shard."""
|
||||
model_dir = tmp_path / "base_model"
|
||||
model_dir.mkdir()
|
||||
|
||||
weights = {
|
||||
"model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden),
|
||||
"model.layers.0.self_attn.v_proj.weight": torch.randn(hidden, hidden),
|
||||
"model.embed_tokens.weight": torch.randn(100, hidden),
|
||||
}
|
||||
safetensors.torch.save_file(weights, model_dir / "model.safetensors")
|
||||
|
||||
# Minimal config files
|
||||
(model_dir / "config.json").write_text("{}")
|
||||
return model_dir, weights
|
||||
|
||||
def test_find_lora_weights(self):
|
||||
lora_state = {
|
||||
"base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
|
||||
8, 32
|
||||
),
|
||||
"base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
|
||||
32, 8
|
||||
),
|
||||
}
|
||||
a, b = find_lora_weights(lora_state, "layers.0.self_attn.q_proj.weight")
|
||||
assert a is not None and b is not None
|
||||
assert a.shape == (8, 32)
|
||||
|
||||
a, b = find_lora_weights(lora_state, "layers.0.self_attn.v_proj.weight")
|
||||
assert a is None and b is None
|
||||
|
||||
def test_merge_tensor_basic(self):
|
||||
hidden = 32
|
||||
r = 8
|
||||
alpha = 16
|
||||
base = torch.randn(hidden, hidden)
|
||||
lora_a = torch.randn(r, hidden)
|
||||
lora_b = torch.randn(hidden, r)
|
||||
scale = alpha / r
|
||||
|
||||
lora_state = {
|
||||
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
|
||||
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
|
||||
}
|
||||
|
||||
config = {"r": r, "lora_alpha": alpha}
|
||||
merged, was_merged = _merge_tensor_with_lora(
|
||||
base, "layer.q_proj.weight", lora_state, scale, config, "cpu"
|
||||
)
|
||||
assert was_merged
|
||||
expected = base + scale * (lora_b @ lora_a)
|
||||
assert torch.allclose(merged, expected, atol=1e-5)
|
||||
|
||||
def test_merge_tensor_rslora_scale(self):
|
||||
"""RSLoRA should use alpha/sqrt(r) as scaling factor."""
|
||||
r = 16
|
||||
alpha = 32
|
||||
standard_scale = alpha / r # 2.0
|
||||
rslora_scale = alpha / math.sqrt(r) # 8.0
|
||||
|
||||
assert rslora_scale != standard_scale
|
||||
assert abs(rslora_scale - 8.0) < 1e-6
|
||||
|
||||
def test_sharded_efficient_merge(self, tmp_path):
|
||||
"""End-to-end test of shard-by-shard merge."""
|
||||
hidden = 32
|
||||
r = 8
|
||||
alpha = 16
|
||||
|
||||
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
|
||||
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
|
||||
|
||||
# Create LoRA weights
|
||||
lora_state = {
|
||||
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
|
||||
r, hidden
|
||||
),
|
||||
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
|
||||
hidden, r
|
||||
),
|
||||
"base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(
|
||||
r, hidden
|
||||
),
|
||||
"base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(
|
||||
hidden, r
|
||||
),
|
||||
}
|
||||
safetensors.torch.save_file(
|
||||
lora_state, adapter_dir / "adapter_model.safetensors"
|
||||
)
|
||||
|
||||
output_dir = tmp_path / "output"
|
||||
merge_lora_sharded_efficient(
|
||||
base_model_path=model_dir,
|
||||
lora_adapter_path=adapter_dir,
|
||||
output_path=output_dir,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
# Verify output exists and has merged weights
|
||||
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
|
||||
scale = alpha / r
|
||||
|
||||
q_key = "model.layers.0.self_attn.q_proj.weight"
|
||||
expected_q = base_weights[q_key] + scale * (
|
||||
lora_state[f"base_model.model.{q_key[:-7]}.lora_B.weight"]
|
||||
@ lora_state[f"base_model.model.{q_key[:-7]}.lora_A.weight"]
|
||||
)
|
||||
assert torch.allclose(merged[q_key], expected_q, atol=1e-5)
|
||||
|
||||
# Embedding should be unchanged
|
||||
assert torch.equal(
|
||||
merged["model.embed_tokens.weight"],
|
||||
base_weights["model.embed_tokens.weight"],
|
||||
)
|
||||
|
||||
def test_dora_merge(self):
|
||||
"""DoRA merge applies magnitude normalization via PEFT."""
|
||||
hidden = 32
|
||||
r = 8
|
||||
alpha = 16
|
||||
scale = alpha / r
|
||||
|
||||
base = torch.randn(hidden, hidden)
|
||||
lora_a = torch.randn(r, hidden)
|
||||
lora_b = torch.randn(hidden, r)
|
||||
magnitude = torch.randn(hidden).abs() + 0.1
|
||||
|
||||
lora_state = {
|
||||
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
|
||||
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
|
||||
"base_model.model.layer.q_proj.lora_magnitude_vector": magnitude,
|
||||
}
|
||||
|
||||
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
|
||||
merged, was_merged = _merge_tensor_with_lora(
|
||||
base,
|
||||
"layer.q_proj.weight",
|
||||
lora_state,
|
||||
scale,
|
||||
config,
|
||||
"cpu",
|
||||
use_dora=True,
|
||||
)
|
||||
assert was_merged
|
||||
|
||||
# The merge should differ from both base and base+delta (DoRA applies normalization)
|
||||
delta = scale * (lora_b @ lora_a)
|
||||
assert not torch.allclose(merged, base, atol=1e-3)
|
||||
assert not torch.allclose(merged, base + delta, atol=1e-3)
|
||||
|
||||
def test_fuse_unfuse_moe_merge(self):
|
||||
"""Test fuse→merge→unfuse for MoE expert weights (WeightConverter path)."""
|
||||
from axolotl.cli.utils.lora_merge import _fuse_and_unfuse_with_merge
|
||||
|
||||
hidden = 16
|
||||
intermediate = 32
|
||||
num_experts = 4
|
||||
r = 4
|
||||
alpha = 8
|
||||
scale = alpha / r
|
||||
|
||||
# Simulate checkpoint format: per-expert separate tensors
|
||||
shard_tensors = {}
|
||||
for i in range(num_experts):
|
||||
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = (
|
||||
torch.randn(intermediate, hidden)
|
||||
)
|
||||
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = (
|
||||
torch.randn(intermediate, hidden)
|
||||
)
|
||||
shard_tensors[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = (
|
||||
torch.randn(hidden, intermediate)
|
||||
)
|
||||
shard_tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn(
|
||||
hidden, hidden
|
||||
)
|
||||
|
||||
# LoRA targets the fused key (runtime format)
|
||||
lora_state = {
|
||||
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight": torch.randn(
|
||||
r, hidden
|
||||
),
|
||||
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight": torch.randn(
|
||||
intermediate * 2, r
|
||||
),
|
||||
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_A.weight": torch.randn(
|
||||
r, intermediate
|
||||
),
|
||||
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_B.weight": torch.randn(
|
||||
hidden, r
|
||||
),
|
||||
}
|
||||
|
||||
# Build converters matching qwen2_moe pattern
|
||||
from transformers.core_model_loading import (
|
||||
Concatenate,
|
||||
MergeModulelist,
|
||||
WeightConverter,
|
||||
)
|
||||
|
||||
converters = [
|
||||
WeightConverter(
|
||||
source_patterns=[
|
||||
"mlp.experts.*.gate_proj.weight",
|
||||
"mlp.experts.*.up_proj.weight",
|
||||
],
|
||||
target_patterns="mlp.experts.gate_up_proj",
|
||||
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
|
||||
),
|
||||
WeightConverter(
|
||||
source_patterns="mlp.experts.*.down_proj.weight",
|
||||
target_patterns="mlp.experts.down_proj",
|
||||
operations=[MergeModulelist(dim=0)],
|
||||
),
|
||||
]
|
||||
|
||||
config = {"r": r, "lora_alpha": alpha}
|
||||
result, merged_count, processed_keys = _fuse_and_unfuse_with_merge(
|
||||
shard_tensors, converters, lora_state, scale, config, "cpu"
|
||||
)
|
||||
|
||||
# Should have merged 2 LoRA targets (gate_up_proj and down_proj)
|
||||
assert merged_count == 2
|
||||
|
||||
# Processed keys include original per-expert keys (removed) + fused keys (added)
|
||||
assert len(processed_keys) > 0
|
||||
|
||||
# Output should be in fused format (runtime keys)
|
||||
assert "model.layers.0.mlp.experts.gate_up_proj" in result
|
||||
assert "model.layers.0.mlp.experts.down_proj" in result
|
||||
|
||||
# Per-expert keys should be removed
|
||||
for i in range(num_experts):
|
||||
assert f"model.layers.0.mlp.experts.{i}.gate_proj.weight" not in result
|
||||
|
||||
# Non-expert tensor should be passed through
|
||||
assert "model.layers.0.self_attn.q_proj.weight" in result
|
||||
|
||||
# Verify fused tensors are 3D (stacked experts)
|
||||
gate_up = result["model.layers.0.mlp.experts.gate_up_proj"]
|
||||
assert gate_up.ndim == 3
|
||||
assert gate_up.shape[0] == num_experts # [num_experts, intermediate*2, hidden]
|
||||
|
||||
# Verify the fused LoRA delta was applied correctly
|
||||
# Reconstruct the fused base (stack per-expert, concat gate+up)
|
||||
gate_stack = torch.stack(
|
||||
[
|
||||
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"]
|
||||
for i in range(num_experts)
|
||||
]
|
||||
)
|
||||
up_stack = torch.stack(
|
||||
[
|
||||
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"]
|
||||
for i in range(num_experts)
|
||||
]
|
||||
)
|
||||
base_fused = torch.cat([gate_stack, up_stack], dim=1)
|
||||
lora_a = lora_state[
|
||||
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight"
|
||||
]
|
||||
lora_b = lora_state[
|
||||
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight"
|
||||
]
|
||||
expected_fused = base_fused + scale * (lora_b @ lora_a)
|
||||
assert torch.allclose(gate_up, expected_fused, atol=1e-5)

    def test_param_wrapper_merge_math(self):
        """ParamWrapper merge via PEFT's get_delta_weight matches manual einsum."""
        num_experts = 4
        r = 2
        in_features = 8
        out_features = 4
        alpha = 4

        base = torch.randn(num_experts, in_features, out_features)
        lora_a = torch.randn(r * num_experts, in_features)
        lora_b = torch.randn(out_features, r * num_experts)

        config = {"r": r, "lora_alpha": alpha}
        delta = _build_peft_layer_and_get_delta(
            lora_a, lora_b, config, base, is_param_wrapper=True
        )
        assert delta.shape == base.shape

        merged = base + delta

        # Verify against manual einsum
        scale = alpha / r
        wa = lora_a.reshape(num_experts, r, in_features)
        wb = lora_b.reshape(out_features, r, num_experts)
        manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
        for e in range(num_experts):
            assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
                f"Expert {e} mismatch"
            )
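        # Equivalent per-expert view (illustrative sketch, not part of the original
        # test): the einsum above is the per-expert matmul (B_e @ A_e).T with
        # A_e = wa[e] ([r, in_features]) and B_e = wb[:, :, e] ([out_features, r]).
        loop_delta = torch.stack(
            [scale * (wb[:, :, e] @ wa[e]).T for e in range(num_experts)]
        )
        assert torch.allclose(loop_delta, manual_delta, atol=1e-5)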

    def test_param_wrapper_nesting_dim_filter(self):
        """_find_param_wrapper_lora skips wrong-dimension LoRA at outer level."""
        num_experts = 4
        r = 2

        # Outer LoRA (gate_up_proj): A=[r*E, 8], B=[16, r*E]
        # Inner LoRA (down_proj via base_layer): A=[r*E, 16], B=[8, r*E]
        lora_state = {
            "base_model.model.mod.experts.lora_A.weight": torch.randn(
                r * num_experts, 8
            ),
            "base_model.model.mod.experts.lora_B.weight": torch.randn(
                16, r * num_experts
            ),
            "base_model.model.mod.experts.base_layer.lora_A.weight": torch.randn(
                r * num_experts, 16
            ),
            "base_model.model.mod.experts.base_layer.lora_B.weight": torch.randn(
                8, r * num_experts
            ),
        }

        # gate_up_proj shape [4, 8, 16] — should match outer LoRA
        a, b, name = _find_param_wrapper_lora(
            lora_state, "mod.experts.gate_up_proj", tensor_shape=(4, 8, 16)
        )
        assert a is not None and name == "gate_up_proj"
        assert a.shape == (r * num_experts, 8)  # outer

        # down_proj shape [4, 16, 8] — outer dims don't match, should find inner
        a, b, name = _find_param_wrapper_lora(
            lora_state, "mod.experts.down_proj", tensor_shape=(4, 16, 8)
        )
        assert a is not None and name == "down_proj"
        assert a.shape == (r * num_experts, 16)  # inner (base_layer)

        # shape that matches neither — should return None
        a, b, name = _find_param_wrapper_lora(
            lora_state, "mod.experts.other", tensor_shape=(4, 99, 99)
        )
        assert a is None

    def test_find_lora_weights_with_renamings(self):
        """Weight renamings let checkpoint keys match LoRA keys."""
        lora_state = {
            "base_model.model.layers.0.mlp.fc1.lora_A.weight": torch.randn(8, 32),
            "base_model.model.layers.0.mlp.fc1.lora_B.weight": torch.randn(32, 8),
        }
        # Direct lookup fails (checkpoint has "ff0", LoRA has "fc1")
        a, b = find_lora_weights(lora_state, "layers.0.mlp.ff0.weight")
        assert a is None

        # With renaming ff0 → fc1, it should match
        a, b = find_lora_weights(
            lora_state, "layers.0.mlp.ff0.weight", weight_renamings={"ff0": "fc1"}
        )
        assert a is not None
        assert a.shape == (8, 32)

    def test_unmatched_tensors_pass_through(self):
        """Tensors with no matching LoRA are returned unchanged."""
        lora_state = {
            "base_model.model.layer.q_proj.lora_A.weight": torch.randn(8, 32),
            "base_model.model.layer.q_proj.lora_B.weight": torch.randn(32, 8),
        }

        # 1D tensor (layernorm) — never matched
        ln = torch.randn(32)
        merged, was_merged = _merge_tensor_with_lora(
            ln, "layer.norm.weight", lora_state, 2.0, {}, "cpu"
        )
        assert not was_merged
        assert torch.equal(merged, ln)

        # 2D tensor with no matching key
        unrelated = torch.randn(64, 32)
        merged, was_merged = _merge_tensor_with_lora(
            unrelated, "layer.other_proj.weight", lora_state, 2.0, {}, "cpu"
        )
        assert not was_merged
        assert torch.equal(merged, unrelated)

    def test_fan_in_fan_out_transpose(self):
        """fan_in_fan_out config transposes the LoRA delta."""
        hidden = 16
        r = 4
        alpha = 4  # scale = 1.0

        base = torch.randn(hidden, hidden)
        lora_a = torch.randn(r, hidden)
        lora_b = torch.randn(hidden, r)

        lora_state = {
            "base_model.model.layer.proj.lora_A.weight": lora_a,
            "base_model.model.layer.proj.lora_B.weight": lora_b,
        }

        config_normal = {"r": r, "lora_alpha": alpha}
        config_fif = {"r": r, "lora_alpha": alpha, "fan_in_fan_out": True}

        merged_normal, _ = _merge_tensor_with_lora(
            base, "layer.proj.weight", lora_state, 1.0, config_normal, "cpu"
        )
        merged_fif, _ = _merge_tensor_with_lora(
            base, "layer.proj.weight", lora_state, 1.0, config_fif, "cpu"
        )

        delta = (alpha / r) * (lora_b @ lora_a)
        assert torch.allclose(merged_normal, base + delta, atol=1e-5)
        assert torch.allclose(merged_fif, base + delta.T, atol=1e-5)
        assert not torch.allclose(merged_normal, merged_fif, atol=1e-5)
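        # fan_in_fan_out note (illustrative): it marks layers whose weight is stored
        # transposed (e.g. transformers' Conv1D in GPT-2 style models), so the merge
        # applies the transposed delta, matching the assertion above:
        #   W_merged = W + ((lora_alpha / r) * (lora_B @ lora_A)).T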

    def test_rslora_end_to_end(self, tmp_path):
        """RSLoRA adapter uses alpha/sqrt(r) scaling in sharded merge."""
        hidden = 16
        r = 16
        alpha = 32

        model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
        adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_rslora=True)

        lora_a = torch.randn(r, hidden)
        lora_b = torch.randn(hidden, r)
        lora_state = {
            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
        }
        safetensors.torch.save_file(
            lora_state, adapter_dir / "adapter_model.safetensors"
        )

        output_dir = tmp_path / "output"
        merge_lora_sharded_efficient(
            base_model_path=model_dir,
            lora_adapter_path=adapter_dir,
            output_path=output_dir,
            device="cpu",
        )

        merged = safetensors.torch.load_file(output_dir / "model.safetensors")
        rslora_scale = alpha / math.sqrt(r)  # 8.0, not 2.0
        q_key = "model.layers.0.self_attn.q_proj.weight"
        expected = base_weights[q_key] + rslora_scale * (lora_b @ lora_a)
        assert torch.allclose(merged[q_key], expected, atol=1e-5)

        # Confirm it differs from standard scale
        wrong_scale = alpha / r  # 2.0
        wrong_expected = base_weights[q_key] + wrong_scale * (lora_b @ lora_a)
        assert not torch.allclose(merged[q_key], wrong_expected, atol=1e-3)
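        # RSLoRA note (illustrative): rank-stabilized LoRA scales the delta by
        # lora_alpha / sqrt(r) instead of lora_alpha / r, so with alpha=32 and r=16
        # the merge uses 32 / 4 = 8.0 rather than 32 / 16 = 2.0, which is exactly
        # what the two assertions above distinguish.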

    def test_multi_shard_index_json(self, tmp_path):
        """Multi-shard merge generates a correct weight-map index."""
        hidden = 16
        r = 4
        alpha = 8

        model_dir = tmp_path / "base_model"
        model_dir.mkdir()
        (model_dir / "config.json").write_text("{}")

        # Create 2 shards
        shard1 = {"model.layers.0.weight": torch.randn(hidden, hidden)}
        shard2 = {"model.layers.1.weight": torch.randn(hidden, hidden)}
        safetensors.torch.save_file(
            shard1, model_dir / "model-00001-of-00002.safetensors"
        )
        safetensors.torch.save_file(
            shard2, model_dir / "model-00002-of-00002.safetensors"
        )

        # Write a base model index (will be skipped by copy_non_model_files)
        base_index = {
            "metadata": {},
            "weight_map": {
                "model.layers.0.weight": "model-00001-of-00002.safetensors",
                "model.layers.1.weight": "model-00002-of-00002.safetensors",
            },
        }
        (model_dir / "model.safetensors.index.json").write_text(json.dumps(base_index))

        adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
        safetensors.torch.save_file({}, adapter_dir / "adapter_model.safetensors")

        output_dir = tmp_path / "output"
        merge_lora_sharded_efficient(
            base_model_path=model_dir,
            lora_adapter_path=adapter_dir,
            output_path=output_dir,
            device="cpu",
        )

        # Verify index was generated
        index_path = output_dir / "model.safetensors.index.json"
        assert index_path.exists()
        with open(index_path) as f:
            idx = json.load(f)

        assert "weight_map" in idx
        assert len(idx["weight_map"]) == 2
        # Each key should map to a shard that exists
        for _key, shard_name in idx["weight_map"].items():
            assert (output_dir / shard_name).exists(), f"Missing shard: {shard_name}"

    def test_dora_end_to_end(self, tmp_path):
        """DoRA merge through the full sharded merge pipeline."""
        hidden = 16
        r = 4
        alpha = 8

        model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
        adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_dora=True)

        lora_a = torch.randn(r, hidden)
        lora_b = torch.randn(hidden, r)
        magnitude = torch.randn(hidden).abs() + 0.1
        lora_state = {
            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
            "base_model.model.model.layers.0.self_attn.q_proj.lora_magnitude_vector": magnitude,
        }
        safetensors.torch.save_file(
            lora_state, adapter_dir / "adapter_model.safetensors"
        )

        output_dir = tmp_path / "output"
        merge_lora_sharded_efficient(
            base_model_path=model_dir,
            lora_adapter_path=adapter_dir,
            output_path=output_dir,
            device="cpu",
        )

        merged = safetensors.torch.load_file(output_dir / "model.safetensors")
        q_key = "model.layers.0.self_attn.q_proj.weight"

        # Use PEFT's own get_delta_weight as the reference
        delta = _build_peft_layer_and_get_delta(
            lora_a,
            lora_b,
            {"r": r, "lora_alpha": alpha, "use_dora": True},
            base_weights[q_key],
            magnitude=magnitude,
        )
        expected = base_weights[q_key] + delta
        assert torch.allclose(merged[q_key], expected, atol=1e-5)

        # Verify it differs from standard (non-DoRA) merge
        standard_delta = _build_peft_layer_and_get_delta(
            lora_a,
            lora_b,
            {"r": r, "lora_alpha": alpha},
            base_weights[q_key],
        )
        assert not torch.allclose(delta, standard_delta, atol=1e-3)

        # v_proj has no LoRA weights — should be unchanged
        v_key = "model.layers.0.self_attn.v_proj.weight"
        assert torch.equal(merged[v_key], base_weights[v_key]), (
            "v_proj should be unchanged (no LoRA weights for it)"
        )
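        # DoRA note (simplified sketch; PEFT's get_delta_weight above is the actual
        # reference): the merged weight is the LoRA-updated weight rescaled per
        # output row by the learned magnitude vector m, roughly
        #   W' = (m / row_norm(W + (alpha / r) * B @ A)) * (W + (alpha / r) * B @ A)
        # so the result depends on m and the row norms, which is why it differs from
        # the plain (non-DoRA) merge checked above.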

    def test_dora_missing_magnitude_falls_back(self):
        """DoRA without magnitude vector falls back to standard LoRA merge."""
        hidden = 16
        r = 4
        alpha = 8
        scale = alpha / r

        base = torch.randn(hidden, hidden)
        lora_a = torch.randn(r, hidden)
        lora_b = torch.randn(hidden, r)

        # No magnitude vector in lora_state
        lora_state = {
            "base_model.model.layer.proj.lora_A.weight": lora_a,
            "base_model.model.layer.proj.lora_B.weight": lora_b,
        }

        config = {"r": r, "lora_alpha": alpha, "use_dora": True}
        merged, was_merged = _merge_tensor_with_lora(
            base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
        )
        assert was_merged
        # No magnitude vector → PEFT creates DoRA layer but with default magnitude,
        # which produces a result different from plain W + scale * B @ A.
        # Just verify it was merged (not unchanged).
        assert not torch.equal(merged, base)