feat:merge-lora iterate through bins without loading (#3095)

* merge_method added

* merge_efficient core implement

* Update src/axolotl/cli/merge_lora.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/lora_merge_efficient.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* standard to legacy + rstrip + try/except for do_merge_lora_efficient(cfg=cfg)

* fix: 'dict' object has no attribute 'lora_alpha'

* info -> debug

* lint

* lint2

* moved everything to cpu + performance improvements

* lint

* Update src/axolotl/cli/merge_lora.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/cli/merge_lora.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* string handling + try/except removal

* merge_method -> merge_lora_methods

* remove duplicate cal + safetensor + move to lora_merge.py

* lint

* handle quant-dequant, handle experts

* fix parameter merging and prefer peft's native merge logic per module

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
VED
2026-03-25 18:11:32 +05:30
committed by GitHub
parent ff0f67c730
commit b55706b9f6
5 changed files with 1735 additions and 5 deletions

View File

@@ -4,9 +4,11 @@ from pathlib import Path
from typing import Union
import fire
import torch
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -17,12 +19,26 @@ LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
along with the LoRA adapters to combine them into a single base model.
Merges LoRA adapters with base model using either memory-efficient or legacy approach.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
merge_method = str(getattr(cfg, "merge_method", "memory_efficient"))
if merge_method == "legacy":
LOG.debug("Using legacy LoRA merging method...")
_do_merge_lora_legacy(cfg=cfg)
else:
LOG.debug("Using memory-efficient LoRA merging method...")
_do_merge_lora_efficient(cfg=cfg)
def _do_merge_lora_legacy(*, cfg: DictDefault) -> None:
"""
Legacy LoRA merging using merge_and_unload.
Loads the full model into memory before merging.
"""
LOG.debug("Using legacy LoRA merging method...")
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
LOG.info("Running merge of LoRA with base model...")
@@ -52,6 +68,58 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
"""
Memory-efficient LoRA merging using shard-by-shard processing.
Does not load the full model into memory.
Supports standard LoRA, RSLoRA, and DoRA. Unsupported methods (AdaLoRA, VeRA)
will raise NotImplementedError — use legacy method for those.
"""
LOG.debug("Using memory-efficient LoRA merging method...")
output_path = Path(cfg.output_dir) / "merged"
safe_tensors = getattr(cfg, "save_safetensors", True)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Detect NF4 quantization from config to simulate QLoRA training dynamics.
# Check both current and original (pre-override) config values since do_cli
# forces load_in_4bit=False for the legacy path.
simulate_nf4 = bool(
getattr(cfg, "load_in_4bit", False)
or getattr(cfg, "_original_load_in_4bit", False)
or getattr(cfg, "adapter", None) == "qlora"
or getattr(cfg, "_original_adapter", None) == "qlora"
)
bnb_config_kwargs = getattr(cfg, "bnb_config_kwargs", None) or {}
nf4_blocksize = bnb_config_kwargs.get("blocksize", None)
nf4_double_quant = bnb_config_kwargs.get(
"bnb_4bit_use_double_quant",
getattr(cfg, "bnb_4bit_use_double_quant", True),
)
# Detect MoE expert quantization
simulate_nf4_experts = bool(
getattr(cfg, "quantize_moe_experts", False)
or getattr(cfg, "_original_quantize_moe_experts", False)
)
merge_lora_sharded_efficient(
base_model_path=cfg.base_model,
lora_adapter_path=cfg.lora_model_dir,
output_path=output_path,
safe_tensors=safe_tensors,
device=device,
simulate_nf4=simulate_nf4,
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
)
LOG.debug("Memory-efficient LoRA merge completed successfully!")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
@@ -66,6 +134,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
ValueError: If target directory for LoRA merged model does not exist.
"""
# Pre-load config to detect original quantization settings before overrides
raw_cfg = load_cfg(config, **kwargs)
original_load_in_4bit = getattr(raw_cfg, "load_in_4bit", False)
original_adapter = getattr(raw_cfg, "adapter", None)
original_quantize_moe_experts = getattr(raw_cfg, "quantize_moe_experts", False)
parsed_cfg = load_cfg(
config,
merge_lora=True,
@@ -80,11 +154,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
**kwargs,
)
# Stash original quantization settings for NF4 simulation in efficient merge
parsed_cfg._original_load_in_4bit = original_load_in_4bit
parsed_cfg._original_adapter = original_adapter
parsed_cfg._original_quantize_moe_experts = original_quantize_moe_experts
if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
raise ValueError(
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
)
do_merge_lora(cfg=parsed_cfg)
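As an aside, a minimal sketch of how the new merge_method option selects a path (placeholder values; a real run also needs the rest of the usual axolotl config that the legacy loader expects):

from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault

# Hypothetical minimal config: any value other than "legacy" (including the
# default "memory_efficient") routes to the shard-by-shard merge.
cfg = DictDefault(
    {
        "base_model": "./base-model",           # placeholder path
        "lora_model_dir": "./outputs/adapter",  # placeholder path
        "output_dir": "./outputs",
        "merge_method": "legacy",
    }
)
do_merge_lora(cfg=cfg)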

View File

@@ -0,0 +1,982 @@
import gc
import math
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union
import safetensors
import safetensors.torch
import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig
from tqdm import tqdm
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _simulate_nf4_roundtrip(
tensor: torch.Tensor,
blocksize: Optional[int] = None,
compress_statistics: bool = True,
device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
"""
Simulate NF4 quantization roundtrip to match QLoRA training dynamics.
During QLoRA training, base weights are quantized to NF4 and dequantized on-the-fly
for each forward pass. The LoRA adapters learn to compensate for the quantization
noise in the dequantized weights. To match this at merge time, we apply the same
quantize → dequantize roundtrip so the merged result reflects what the model saw
during training.
Args:
tensor: Base model weight tensor (fp16/bf16/fp32)
blocksize: NF4 quantization block size (default: bitsandbytes default)
compress_statistics: Whether to use double quantization
device: Device for quantization computation. bitsandbytes requires a
CUDA device; defaults to "cuda" when available.
Returns:
Tensor after NF4 quantize → dequantize roundtrip, in original dtype
"""
import bitsandbytes.functional as bnb_F
quant_device: torch.device
if device is None:
quant_device = torch.device("cuda")
elif isinstance(device, str):
quant_device = torch.device(device)
else:
quant_device = device
if quant_device.type == "cuda" and not torch.cuda.is_available():
raise RuntimeError(
"NF4 simulation requires CUDA but no GPU is available. "
"Either run on a machine with a GPU or disable NF4 simulation."
)
original_dtype = tensor.dtype
original_shape = tensor.shape
# bitsandbytes requires float32 input for quantization and contiguous+CUDA tensor
flat = tensor.reshape(-1).to(torch.float32).contiguous().to(quant_device)
quant_kwargs = {
"quant_type": "nf4",
"compress_statistics": compress_statistics,
}
if blocksize is not None:
quant_kwargs["blocksize"] = blocksize
quantized, quant_state = bnb_F.quantize_4bit(flat, **quant_kwargs)
dequantized = bnb_F.dequantize_4bit(quantized, quant_state, quant_type="nf4")
return dequantized.reshape(original_shape).to(original_dtype).cpu()
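A minimal usage sketch for the roundtrip above (needs a CUDA device and bitsandbytes; the shape and blocksize are illustrative):

import torch

w = torch.randn(4096, 4096, dtype=torch.bfloat16)
w_nf4 = _simulate_nf4_roundtrip(w, blocksize=64, compress_statistics=True, device="cuda")
assert w_nf4.shape == w.shape and w_nf4.dtype == w.dtype
# w_nf4 carries the same quantization noise the adapter compensated for during
# QLoRA training, so merging the LoRA delta into w_nf4 rather than w keeps the
# merged weights consistent with what the model saw at training time.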
def find_lora_weights(
lora_state: Dict[str, torch.Tensor],
key: str,
weight_renamings: Optional[Dict[str, str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Find corresponding LoRA A and B weights for a given key.
Also tries keys after applying weight renamings (from transformers v5
conversion mappings) in case the checkpoint key names differ from the
runtime model key names used by the LoRA adapter.
"""
import re
clean_key = key[:-7] if key.endswith(".weight") else key
# Try the direct key first
a_key = f"base_model.model.{clean_key}.lora_A.weight"
b_key = f"base_model.model.{clean_key}.lora_B.weight"
lora_a = lora_state.get(a_key)
lora_b = lora_state.get(b_key)
if lora_a is not None and lora_b is not None:
return lora_a, lora_b
# Try renamed keys (checkpoint format → runtime format)
if weight_renamings:
for src_pattern, tgt_pattern in weight_renamings.items():
renamed_key = re.sub(src_pattern, tgt_pattern, clean_key)
if renamed_key != clean_key:
a_key = f"base_model.model.{renamed_key}.lora_A.weight"
b_key = f"base_model.model.{renamed_key}.lora_B.weight"
lora_a = lora_state.get(a_key)
lora_b = lora_state.get(b_key)
if lora_a is not None and lora_b is not None:
return lora_a, lora_b
return None, None
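Illustration of the key mapping this lookup performs (the adapter keys follow the PEFT checkpoint layout also used in the tests below):

import torch

lora_state = {
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 32),
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(32, 8),
}
a, b = find_lora_weights(lora_state, "model.layers.0.self_attn.q_proj.weight")
# a.shape == (8, 32), b.shape == (32, 8); keys with no adapter entry return (None, None)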
def _find_param_wrapper_lora(
lora_state: Dict[str, torch.Tensor],
key: str,
tensor_shape: Optional[tuple] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[str]]:
"""
Find LoRA weights from a ParamWrapper (lora_target_parameters) that targets
a parent module containing this weight as a sub-parameter.
For example, base weight key 'model.layers.0.mlp.experts.down_proj' may have
LoRA at 'base_model.model.model.layers.0.mlp.experts.lora_A.weight' (targeting
the 'experts' module with 'down_proj' as the parameter_name).
When tensor_shape is provided, validates that the LoRA dimensions match the
target tensor (important when multiple ParamWrappers are nested and each
nesting level has different LoRA dimensions).
Returns (lora_A, lora_B, parameter_name) or (None, None, None).
"""
clean_key = key[:-7] if key.endswith(".weight") else key
# Strip trailing parameter name to get the parent module path
# e.g., "model.layers.0.mlp.experts.down_proj" → parent="model.layers.0.mlp.experts", param="down_proj"
parts = clean_key.rsplit(".", 1)
if len(parts) != 2:
return None, None, None
parent_key, param_name = parts
# PEFT's ParamWrapper nesting: when multiple parameters are targeted on
# the same module, it nests wrappers. The outer wrapper's LoRA is at
# parent.lora_A/B and inner wrappers use parent.base_layer.lora_A/B,
# parent.base_layer.base_layer.lora_A/B, etc.
prefixes_to_try = [
f"base_model.model.{parent_key}",
]
# Walk up .base_layer nesting levels (typically 1-2 deep)
for depth in range(1, 4):
bl = ".base_layer" * depth
prefixes_to_try.append(f"base_model.model.{parent_key}{bl}")
for prefix in prefixes_to_try:
a_key = f"{prefix}.lora_A.weight"
b_key = f"{prefix}.lora_B.weight"
lora_a = lora_state.get(a_key)
lora_b = lora_state.get(b_key)
if lora_a is None or lora_b is None:
continue
# When tensor_shape is given, verify dimensions match before returning.
# This prevents returning a mismatched LoRA from a different nesting level.
if tensor_shape is not None and len(tensor_shape) >= 3:
num_experts = tensor_shape[0]
if not (
lora_a.shape[0] == lora_b.shape[1]
and lora_a.shape[0] % num_experts == 0
and lora_a.shape[1] == tensor_shape[1]
and lora_b.shape[0] == tensor_shape[2]
):
continue # Dimensions don't match, try next nesting level
return lora_a, lora_b, param_name
return None, None, None
def _build_peft_layer_and_get_delta(
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_config_dict: Dict,
base_tensor: torch.Tensor,
adapter_name: str = "default",
is_param_wrapper: bool = False,
magnitude: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Use PEFT's own layer classes to compute the LoRA delta weight.
Instead of re-implementing the merge math for every LoRA variant, this
constructs a lightweight PEFT layer, loads the A/B weights, and calls
``get_delta_weight`` (or ``merge`` for DoRA) which handles standard LoRA,
RSLoRA, DoRA, and ParamWrapper (expert-blocked) LoRA.
Returns the delta tensor (same shape as base_tensor).
"""
import warnings
import torch.nn as nn
r_total = lora_a.shape[0]
in_features = lora_a.shape[1]
out_features = lora_b.shape[0]
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
use_rslora = bool(lora_config_dict.get("use_rslora", False))
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
if is_param_wrapper:
from peft.tuners.lora.layer import ParamWrapper
num_experts = base_tensor.shape[0]
r = r_total // num_experts
class _FakeModule(nn.Module):
pass
fake = _FakeModule()
fake.register_parameter(
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
layer = ParamWrapper(
fake,
adapter_name=adapter_name,
parameter_name="weight",
r=r,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
return layer.get_delta_weight(adapter_name)
else:
from peft.tuners.lora.layer import Linear as LoraLinear
base_layer = nn.Linear(in_features, out_features, bias=False)
base_layer.weight.data = base_tensor.clone()
fan_in_fan_out = bool(
lora_config_dict.get("fan_in_fan_out", False)
or lora_config_dict.get("lora_fan_in_fan_out", False)
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
r=r_total,
lora_alpha=lora_alpha,
fan_in_fan_out=fan_in_fan_out,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
if use_dora:
# DoRA merges magnitude normalization into the weight directly.
# Use PEFT's merge() which handles DoRA internally, then
# compute the delta as merged_weight - original_weight.
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
return base_layer.weight.data - base_tensor
return layer.get_delta_weight(adapter_name)
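A quick sanity-check sketch for the standard-LoRA case (assumes PEFT is installed): the PEFT-computed delta reduces to (alpha / r) * B @ A, while DoRA goes through PEFT's merge() as noted above.

import torch

r, hidden, alpha = 8, 32, 16
lora_a, lora_b = torch.randn(r, hidden), torch.randn(hidden, r)
base = torch.randn(hidden, hidden)
delta = _build_peft_layer_and_get_delta(lora_a, lora_b, {"r": r, "lora_alpha": alpha}, base)
assert torch.allclose(delta, (alpha / r) * (lora_b @ lora_a), atol=1e-5)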
def get_model_shards(model_path: Path) -> list[Path]:
"""Find all model shards in the given path."""
shards: list[Path] = []
patterns = ["model*.safetensors", "pytorch_model*.bin"]
for pattern in patterns:
shards.extend(model_path.glob(pattern))
if shards:
break
return sorted(shards)
def copy_non_model_files(
input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
"""
Copy all non-model files to the output directory.
Args:
input_path: Source directory
output_path: Destination directory
model_shards: List of model shard files to skip
"""
LOG.info("Copying non-model files to output directory...")
shard_names = {shard.name for shard in model_shards}
for filepath in input_path.glob("*"):
if filepath.is_dir():
continue
if filepath.name in shard_names:
continue
if (
filepath.name.startswith("model") and filepath.suffix == ".safetensors"
) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
continue
if filepath.suffix == ".gguf":
continue
# Skip weight-map index files — they reference shard filenames that may
# change during the merge (e.g. .bin → .safetensors). A correct index
# is regenerated after all shards have been written.
if filepath.name.endswith(".index.json"):
continue
LOG.debug(f"Copying {filepath.name} to output")
shutil.copy2(filepath, output_path)
def _find_dora_magnitude(
lora_state: Dict[str, torch.Tensor],
key: str,
weight_renamings: Optional[Dict[str, str]] = None,
) -> Optional[torch.Tensor]:
"""
Find DoRA magnitude vector for a given key.
"""
import re
clean_key = key[:-7] if key.endswith(".weight") else key
mag_key = f"base_model.model.{clean_key}.lora_magnitude_vector"
result = lora_state.get(mag_key)
if result is not None:
return result
if weight_renamings:
for src_pattern, tgt_pattern in weight_renamings.items():
renamed_key = re.sub(src_pattern, tgt_pattern, clean_key)
if renamed_key != clean_key:
mag_key = f"base_model.model.{renamed_key}.lora_magnitude_vector"
result = lora_state.get(mag_key)
if result is not None:
return result
return None
def _should_nf4_roundtrip(
key: str,
tensor: torch.Tensor,
simulate_nf4: bool,
simulate_nf4_experts: bool,
) -> bool:
"""Determine if a tensor should undergo NF4 quantization roundtrip."""
if tensor.ndim < 2:
return False
if simulate_nf4:
return True
if simulate_nf4_experts and tensor.ndim >= 3 and "expert" in key.lower():
return True
return False
def _merge_tensor_with_lora(
tensor: torch.Tensor,
key: str,
lora_state: Dict[str, torch.Tensor],
scale: float,
lora_config_dict: Dict,
device: str,
simulate_nf4: bool = False,
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
) -> tuple[torch.Tensor, bool]:
"""
Helper function to merge a single tensor with its corresponding LoRA weights.
Args:
tensor: Base model tensor
key: Tensor key/name
lora_state: Dictionary containing LoRA weights
scale: LoRA scaling factor (alpha/r)
lora_config_dict: LoRA configuration dictionary
device: Device to perform computations on
simulate_nf4: Whether to simulate NF4 quantization roundtrip for all weights
simulate_nf4_experts: Whether to simulate NF4 roundtrip for MoE expert tensors only
nf4_blocksize: Block size for NF4 quantization
nf4_double_quant: Whether to use double quantization
use_dora: Whether to apply DoRA (Weight-Decomposed LoRA) merging
weight_renamings: Optional key renamings from transformers conversion mapping
Returns:
Tuple of (merged tensor, whether LoRA was applied)
"""
lora_a, lora_b = find_lora_weights(lora_state, key, weight_renamings)
do_nf4 = _should_nf4_roundtrip(key, tensor, simulate_nf4, simulate_nf4_experts)
if lora_a is not None and lora_b is not None:
LOG.debug(f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}")
original_dtype = tensor.dtype
# Simulate NF4 quantization roundtrip to match QLoRA training dynamics
if do_nf4:
tensor = _simulate_nf4_roundtrip(
tensor,
blocksize=nf4_blocksize,
compress_statistics=nf4_double_quant,
device=device,
)
magnitude = (
_find_dora_magnitude(lora_state, key, weight_renamings)
if use_dora
else None
)
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
)
merged_tensor = (
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
.to(original_dtype)
.detach()
.cpu()
)
return merged_tensor, True
else:
# Try ParamWrapper LoRA (lora_target_parameters) — the LoRA targets a
# parent module and this weight is a sub-parameter of that module.
if tensor.ndim >= 3:
pw_a, pw_b, param_name = _find_param_wrapper_lora(
lora_state, key, tensor_shape=tuple(tensor.shape)
)
if pw_a is not None and pw_b is not None:
LOG.debug(
f"Merging ParamWrapper LoRA for {key} "
f"(param={param_name}): {pw_a.shape}, {pw_b.shape}"
)
if do_nf4:
tensor = _simulate_nf4_roundtrip(
tensor,
blocksize=nf4_blocksize,
compress_statistics=nf4_double_quant,
device=device,
)
original_dtype = tensor.dtype
delta = _build_peft_layer_and_get_delta(
pw_a.to(device),
pw_b.to(device),
lora_config_dict,
tensor.to(device),
is_param_wrapper=True,
)
merged = (
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
.to(original_dtype)
.detach()
.cpu()
)
return merged, True
if do_nf4:
tensor = _simulate_nf4_roundtrip(
tensor,
blocksize=nf4_blocksize,
compress_statistics=nf4_double_quant,
device=device,
)
return tensor.detach().cpu(), False
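Minimal per-tensor usage sketch (same shapes and scale as the unit test further down):

import torch

base = torch.randn(32, 32)
lora_state = {
    "base_model.model.layer.q_proj.lora_A.weight": torch.randn(8, 32),
    "base_model.model.layer.q_proj.lora_B.weight": torch.randn(32, 8),
}
merged, applied = _merge_tensor_with_lora(
    base, "layer.q_proj.weight", lora_state, 2.0, {"r": 8, "lora_alpha": 16}, "cpu"
)
# applied is True; merged equals base + (16 / 8) * (lora_B @ lora_A) within float tolerance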
def _get_conversion_info(base_model_path: Path) -> tuple[Dict[str, str], list]:
"""
Load the model's config.json and check if transformers has WeightRenaming
or WeightConverter mappings for this model type.
Returns:
- dict of {source_pattern: target_pattern} for simple renamings
- list of WeightConverter objects for fuse/unfuse operations
"""
import json as _json
config_path = base_model_path / "config.json"
if not config_path.exists():
return {}, []
try:
with open(config_path) as f:
model_config = _json.load(f)
except (OSError, _json.JSONDecodeError):
return {}, []
model_type = model_config.get("model_type")
if not model_type:
return {}, []
try:
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
from transformers.core_model_loading import WeightConverter, WeightRenaming
except ImportError:
return {}, []
conversions = get_checkpoint_conversion_mapping(model_type)
if not conversions:
return {}, []
renamings = {}
weight_converters = []
for conv in conversions:
if isinstance(conv, WeightRenaming):
# WeightRenaming stores patterns as lists internally
src_list = (
conv.source_patterns
if isinstance(conv.source_patterns, list)
else [conv.source_patterns]
)
tgt_list = (
conv.target_patterns
if isinstance(conv.target_patterns, list)
else [conv.target_patterns]
)
if len(src_list) == 1 and len(tgt_list) == 1:
renamings[src_list[0]] = tgt_list[0]
elif isinstance(conv, WeightConverter):
weight_converters.append(conv)
return renamings, weight_converters
def _fuse_and_unfuse_with_merge(
shard_tensors: Dict[str, torch.Tensor],
weight_converters: list,
lora_state: Dict[str, torch.Tensor],
scale: float,
lora_config_dict: Dict,
device: str,
simulate_nf4: bool = False,
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
) -> tuple[Dict[str, torch.Tensor], int, set]:
"""
For tensors matching WeightConverter patterns (MoE expert weights):
1. Fuse checkpoint-format tensors into runtime-format (e.g., per-expert → fused 3D)
2. Apply NF4 roundtrip + LoRA merge on the fused tensor
3. Unfuse back to checkpoint format for saving
Returns:
- Updated tensor dict
- Count of merged LoRA targets
- Set of keys that were processed (fused/merged/unfused) and should be
skipped by the per-tensor merge pass to avoid double NF4 roundtrip
"""
import re
from transformers.core_model_loading import Concatenate, MergeModulelist
result = dict(shard_tensors) # Start with all tensors
merged_count = 0
processed_keys: set = set() # Keys that were fuse/unfuse processed
for converter in weight_converters:
src_patterns = (
converter.source_patterns
if isinstance(converter.source_patterns, list)
else [converter.source_patterns]
)
tgt_patterns = (
converter.target_patterns
if isinstance(converter.target_patterns, list)
else [converter.target_patterns]
)
# Build regex for each source pattern
pattern_regexes = []
for pat in src_patterns:
regex_str = re.escape(pat).replace(r"\.\*\.", r"\.(\d+)\.")
regex_str = (
regex_str.rstrip(r"\$") if regex_str.endswith(r"\$") else regex_str
)
pattern_regexes.append(re.compile(r"(.*\.)?" + regex_str + "$"))
# Group matching keys by layer prefix and source pattern
# {layer_prefix: {pat_idx: {expert_idx: (key, tensor)}}}
layer_groups: Dict[str, Dict[int, Dict[int, tuple[str, torch.Tensor]]]] = {}
for key in list(result.keys()):
for pat_idx, pat_regex in enumerate(pattern_regexes):
match = pat_regex.match(key)
if match:
prefix = match.group(1) or ""
# Extract expert index from the matched portion
remaining = key[len(prefix) :]
expert_match = re.search(r"\.(\d+)\.", remaining)
expert_idx = int(expert_match.group(1)) if expert_match else 0
layer_groups.setdefault(prefix, {}).setdefault(pat_idx, {})[
expert_idx
] = (key, result[key])
break
# Process each layer group
for prefix, pat_groups in layer_groups.items():
# Check we have all source patterns for this layer
if not pat_groups:
continue
# Step 1: Fuse — MergeModulelist (stack experts) per source pattern
fused_per_pattern = {}
original_keys_per_pattern: Dict[int, list[str]] = {}
num_experts = None
for pat_idx in sorted(pat_groups.keys()):
expert_data = pat_groups[pat_idx]
sorted_indices = sorted(expert_data.keys())
if num_experts is None:
num_experts = len(sorted_indices)
sorted_tensors = [expert_data[idx][1] for idx in sorted_indices]
original_keys_per_pattern[pat_idx] = [
expert_data[idx][0] for idx in sorted_indices
]
fused_per_pattern[src_patterns[pat_idx]] = torch.stack(
sorted_tensors, dim=0
)
# Apply remaining operations (Concatenate)
fused_tensor = None
has_concat = False
concat_dim = 1 # default
for op in converter.operations:
if isinstance(op, MergeModulelist):
pass # Already handled
elif isinstance(op, Concatenate):
has_concat = True
concat_dim = op.dim
tensors_to_cat = [
fused_per_pattern[sp]
for sp in src_patterns
if sp in fused_per_pattern
]
if len(tensors_to_cat) > 1:
fused_tensor = torch.cat(tensors_to_cat, dim=concat_dim)
elif tensors_to_cat:
fused_tensor = tensors_to_cat[0]
if not has_concat and len(fused_per_pattern) == 1:
fused_tensor = next(iter(fused_per_pattern.values()))
if fused_tensor is None:
continue
# Step 2: Build the fused key name and merge LoRA
fused_key = prefix + tgt_patterns[0]
# Apply NF4 roundtrip on the fused tensor (matching training dynamics)
do_nf4 = _should_nf4_roundtrip(
fused_key, fused_tensor, simulate_nf4, simulate_nf4_experts
)
if do_nf4:
fused_tensor = _simulate_nf4_roundtrip(
fused_tensor,
blocksize=nf4_blocksize,
compress_statistics=nf4_double_quant,
device=device,
)
# Try to find and merge LoRA weights for the fused key
lora_a, lora_b = find_lora_weights(lora_state, fused_key, weight_renamings)
if lora_a is not None and lora_b is not None:
LOG.debug(
f"Merging LoRA for fused key {fused_key}: {lora_a.shape}, {lora_b.shape}"
)
original_dtype = fused_tensor.dtype
magnitude = (
_find_dora_magnitude(lora_state, fused_key, weight_renamings)
if use_dora
else None
)
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
fused_tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
)
fused_tensor = (
(
fused_tensor.to(device).to(torch.float32)
+ delta.to(torch.float32)
)
.to(original_dtype)
.detach()
.cpu()
)
merged_count += 1
# Step 3: Save in fused format (runtime format) so that the merged
# model can be loaded directly without needing WeightConverter
# fusion during from_pretrained (which can OOM for large MoE models).
# Remove the original per-expert keys and save the fused tensor
# under the runtime key name.
for pat_idx in sorted(original_keys_per_pattern.keys()):
for ok in original_keys_per_pattern[pat_idx]:
result.pop(ok, None)
processed_keys.add(ok)
result[fused_key] = fused_tensor.detach().cpu()
processed_keys.add(fused_key)
return result, merged_count, processed_keys
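Shape sketch of the fuse step (hypothetical sizes: 4 experts, hidden 16, intermediate 32; mirrors the qwen2_moe-style converters exercised in the tests below):

import torch

gate = torch.stack([torch.randn(32, 16) for _ in range(4)])  # MergeModulelist -> [4, 32, 16]
up = torch.stack([torch.randn(32, 16) for _ in range(4)])    # MergeModulelist -> [4, 32, 16]
fused_gate_up = torch.cat([gate, up], dim=1)                 # Concatenate(dim=1) -> [4, 64, 16]
# LoRA for the fused key ending in "experts.gate_up_proj" is merged on this
# [num_experts, 2*intermediate, hidden] tensor, which is then saved as-is in
# runtime (fused) format.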
def merge_lora_sharded_efficient(
base_model_path: Union[str, Path],
lora_adapter_path: Union[str, Path],
output_path: Union[str, Path],
device: str = "cpu",
safe_tensors: bool = True,
simulate_nf4: bool = False,
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
without loading the full model into memory.
Args:
simulate_nf4: Apply NF4 roundtrip to ALL weight tensors (for QLoRA)
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
(for quantize_moe_experts). Expert tensors are identified by having
"expert" in the key name and ndim >= 3.
"""
base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
output_path = Path(output_path)
if "/" in str(base_model_path) and not base_model_path.exists():
base_model_path = Path(snapshot_download(str(base_model_path)))
# Check for weight conversion requirements (transformers v5)
weight_renamings, weight_converters = _get_conversion_info(base_model_path)
if weight_renamings:
LOG.debug(f"Found {len(weight_renamings)} weight renamings for this model type")
if weight_converters:
LOG.debug(
f"Found {len(weight_converters)} weight converters (fuse/unfuse) for this model type. "
f"Will fuse→merge→unfuse within each shard."
)
os.makedirs(output_path, exist_ok=True)
config_file = lora_adapter_path / "adapter_config.json"
if not config_file.exists():
raise FileNotFoundError(f"LoRA config not found: {config_file}")
lora_config_dict = LoraConfig.from_json_file(str(config_file))
if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0:
raise ValueError("LoRA config 'r' must be > 0")
use_dora = bool(lora_config_dict.get("use_dora", False))
unsupported_methods = []
# Check for AdaLoRA (Adaptive LoRA)
if lora_config_dict.get("use_adalora", False):
unsupported_methods.append("AdaLoRA (Adaptive LoRA)")
# Check for VeRA (Vector-based Random Matrix Adaptation)
if lora_config_dict.get("use_vera", False):
unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)")
# Check for other advanced LoRA variants by task_type
task_type = lora_config_dict.get("task_type", "")
if task_type and task_type not in [
"CAUSAL_LM",
"SEQ_2_SEQ_LM",
"TOKEN_CLS",
"SEQ_CLS",
"QUESTION_ANS",
]:
unsupported_methods.append(f"Task type: {task_type}")
# Check for rank adaptation patterns (AdaLoRA indicators)
# Use .get() so empty dicts/None don't false-positive
if any(
lora_config_dict.get(key)
for key in ["rank_pattern", "alpha_pattern", "target_rank"]
):
unsupported_methods.append("AdaLoRA (rank adaptation detected)")
# Check for advanced initialization methods
init_lora_weights = lora_config_dict.get("init_lora_weights", "")
if init_lora_weights and init_lora_weights not in [
"gaussian",
"loftq",
True,
False,
]:
unsupported_methods.append(f"Advanced initialization: {init_lora_weights}")
if unsupported_methods:
methods_str = ", ".join(unsupported_methods)
raise NotImplementedError(
f"Memory-efficient LoRA merge only supports standard LoRA. "
f"Detected unsupported methods: {methods_str}. "
f"Please use the legacy merge method for advanced LoRA variants."
)
use_rslora = bool(lora_config_dict.get("use_rslora", False))
if use_rslora:
scale = float(lora_config_dict["lora_alpha"]) / math.sqrt(
float(lora_config_dict["r"])
)
else:
scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"])
LOG.debug(f"LoRA scale factor: {scale} (rslora={use_rslora})")
if simulate_nf4:
LOG.info(
"NF4 simulation enabled: base weights will undergo quantize→dequantize "
"roundtrip before LoRA merge to match QLoRA training dynamics"
)
lora_file = lora_adapter_path / "adapter_model.safetensors"
if not lora_file.exists():
lora_file = lora_adapter_path / "adapter_model.bin"
if not lora_file.exists():
raise FileNotFoundError(
f"LoRA adapter weights not found in {lora_adapter_path}"
)
LOG.debug(f"Loading LoRA weights from {lora_file}")
if lora_file.suffix == ".safetensors":
lora_state = safetensors.torch.load_file(lora_file)
else:
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614
LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")
model_shards = get_model_shards(base_model_path)
if not model_shards:
raise FileNotFoundError(f"No model shards found in {base_model_path}")
LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
copy_non_model_files(base_model_path, output_path, model_shards)
merged_count = 0
total_tensors = 0
# Track weight_map for index regeneration: {tensor_key: shard_filename}
weight_map: Dict[str, str] = {}
for shard_path in tqdm(model_shards, desc="Merging shards"):
merged_tensors = {}
metadata = {}
# Load all tensors from the shard
if shard_path.suffix == ".safetensors":
with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
if hasattr(f, "metadata") and f.metadata():
metadata = f.metadata()
shard_tensors = {key: f.get_tensor(key) for key in f.keys()}
else:
shard_tensors = torch.load( # nosec B614: loading trusted model weights
shard_path, map_location="cpu", weights_only=True
)
total_tensors += len(shard_tensors)
# Step 1: Handle fused weight conversions (MoE experts) if applicable
fused_keys: set = set()
if weight_converters:
shard_tensors, fused_merged, fused_keys = _fuse_and_unfuse_with_merge(
shard_tensors,
weight_converters,
lora_state,
scale,
lora_config_dict,
device,
simulate_nf4=simulate_nf4,
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
)
merged_count += fused_merged
# Step 2: Merge remaining (non-fused) tensors with LoRA
# Skip keys already processed by fuse/unfuse to avoid double NF4 roundtrip
for key, tensor in shard_tensors.items():
if key in fused_keys:
merged_tensors[key] = tensor.detach().cpu()
continue
merged_tensor, was_merged = _merge_tensor_with_lora(
tensor,
key,
lora_state,
scale,
lora_config_dict,
device,
simulate_nf4=simulate_nf4,
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
)
merged_tensors[key] = merged_tensor
if was_merged:
merged_count += 1
output_shard_path = output_path / shard_path.name
merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
if safe_tensors:
if not str(output_shard_path).endswith(".safetensors"):
output_shard_path = output_path / (shard_path.stem + ".safetensors")
safetensors.torch.save_file(
merged_tensors, output_shard_path, metadata=metadata
)
else:
if shard_path.suffix == ".safetensors":
safetensors.torch.save_file(
merged_tensors, output_shard_path, metadata=metadata
)
else:
torch.save(merged_tensors, output_shard_path)
for tensor_key in merged_tensors:
weight_map[tensor_key] = output_shard_path.name
del merged_tensors, shard_tensors
if device != "cpu" and torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Regenerate weight-map index if the model was sharded
if len(model_shards) > 1 and weight_map:
import json as _json
index_name = (
"model.safetensors.index.json"
if safe_tensors
else "pytorch_model.bin.index.json"
)
index = {
"metadata": {"total_size": total_tensors},
"weight_map": weight_map,
}
with open(output_path / index_name, "w") as f:
_json.dump(index, f, indent=2)
LOG.debug(f"Wrote weight-map index: {index_name}")
if merged_count == 0:
LOG.warning(
"No LoRA weights were matched to base model tensors. "
"This may indicate a key name mismatch between the checkpoint format "
"and the LoRA adapter. Consider using merge_method: legacy."
)
LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")

View File

@@ -164,6 +164,16 @@ def patch_peft_target_parameters_matching():
from peft.utils.integrations import init_empty_weights
from peft.utils.other import _get_submodules
# Mapping from unfused parameter names to their fused equivalents.
# When a model stores fused weights (e.g. gate_up_proj) but the user
# specifies unfused names (gate_proj, up_proj), we auto-expand so the
# fused parameter is also targeted. The original unfused names are kept
# in the set so that models that do NOT fuse still work.
_UNFUSED_TO_FUSED: dict[str, str] = {
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
@@ -176,10 +186,43 @@ def patch_peft_target_parameters_matching():
continue
for target in original_targets:
mod_path, _, param_name = target.rpartition(".")
if (
if not (
module_name == mod_path or module_name.endswith("." + mod_path)
) and hasattr(module, param_name):
):
continue
if hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
elif param_name in _UNFUSED_TO_FUSED:
# The model uses fused weights (e.g. gate_up_proj) but the
# user specified unfused names (gate_proj / up_proj).
fused_name = _UNFUSED_TO_FUSED[param_name]
if hasattr(module, fused_name):
if fused_name not in expanded:
LOG.warning(
"target_parameter '%s' not found on %s, "
"but fused equivalent '%s' exists — adding "
"it automatically.",
param_name,
module_name,
fused_name,
)
expanded.add(f"{module_name}.{fused_name}")
else:
LOG.warning(
"target_parameter '%s' not found on %s and no "
"fused equivalent exists either — skipping.",
param_name,
module_name,
)
else:
LOG.warning(
"target_parameter '%s' not found on %s — skipping. "
"Check that the parameter name matches the model's "
"weight names.",
param_name,
module_name,
)
target_names_set = expanded
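A small sketch of the fallback behaviour above, assuming a hypothetical MoE module that only exposes the fused parameter:

import torch
import torch.nn as nn

class Experts(nn.Module):
    def __init__(self):
        super().__init__()
        # fused weight only; no separate gate_proj / up_proj parameters
        self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 64))

module = Experts()
param_name = "gate_proj"  # what the user listed in lora_target_parameters
if not hasattr(module, param_name) and param_name in _UNFUSED_TO_FUSED:
    param_name = _UNFUSED_TO_FUSED[param_name]  # "gate_up_proj" is targeted instead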

View File

@@ -155,6 +155,12 @@ class LoraConfig(BaseModel):
)
merge_lora: bool | None = None
merge_method: Literal["legacy", "memory_efficient"] | None = Field(
default="memory_efficient",
json_schema_extra={
"description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory."
},
)
@model_validator(mode="before")
@classmethod

View File

@@ -1,8 +1,18 @@
import json
import math
from unittest.mock import Mock, patch
import safetensors.torch
import torch
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.cli.utils.lora_merge import (
_build_peft_layer_and_get_delta,
_find_param_wrapper_lora,
_merge_tensor_with_lora,
find_lora_weights,
merge_lora_sharded_efficient,
)
from axolotl.utils.dict import DictDefault
@@ -132,6 +142,7 @@ class TestAdapterMergeUnmerge:
"torch_dtype": torch.float32,
"local_rank": 0,
"output_dir": str(tmp_path),
"merge_method": "legacy",
}
)
@@ -167,6 +178,7 @@ class TestAdapterMergeUnmerge:
"save_safetensors": True,
"output_dir": str(tmp_path),
"local_rank": 0,
"merge_method": "legacy",
}
)
@@ -179,3 +191,611 @@ class TestAdapterMergeUnmerge:
do_merge_lora(cfg=cfg)
assert mock_load.called
class TestEfficientMerge:
"""Test suite for memory-efficient shard-by-shard LoRA merge."""
def _make_adapter(self, tmp_path, r=8, alpha=16, use_dora=False, use_rslora=False):
"""Create a minimal adapter directory with config + weights."""
adapter_dir = tmp_path / "adapter"
adapter_dir.mkdir()
config = {
"r": r,
"lora_alpha": alpha,
"target_modules": ["q_proj", "v_proj"],
"task_type": "CAUSAL_LM",
"bias": "none",
"use_dora": use_dora,
"use_rslora": use_rslora,
}
(adapter_dir / "adapter_config.json").write_text(json.dumps(config))
return adapter_dir, config
def _make_base_model(self, tmp_path, hidden=32):
"""Create a minimal base model directory with one shard."""
model_dir = tmp_path / "base_model"
model_dir.mkdir()
weights = {
"model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden),
"model.layers.0.self_attn.v_proj.weight": torch.randn(hidden, hidden),
"model.embed_tokens.weight": torch.randn(100, hidden),
}
safetensors.torch.save_file(weights, model_dir / "model.safetensors")
# Minimal config files
(model_dir / "config.json").write_text("{}")
return model_dir, weights
def test_find_lora_weights(self):
lora_state = {
"base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
8, 32
),
"base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
32, 8
),
}
a, b = find_lora_weights(lora_state, "layers.0.self_attn.q_proj.weight")
assert a is not None and b is not None
assert a.shape == (8, 32)
a, b = find_lora_weights(lora_state, "layers.0.self_attn.v_proj.weight")
assert a is None and b is None
def test_merge_tensor_basic(self):
hidden = 32
r = 8
alpha = 16
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
scale = alpha / r
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
}
config = {"r": r, "lora_alpha": alpha}
merged, was_merged = _merge_tensor_with_lora(
base, "layer.q_proj.weight", lora_state, scale, config, "cpu"
)
assert was_merged
expected = base + scale * (lora_b @ lora_a)
assert torch.allclose(merged, expected, atol=1e-5)
def test_merge_tensor_rslora_scale(self):
"""RSLoRA should use alpha/sqrt(r) as scaling factor."""
r = 16
alpha = 32
standard_scale = alpha / r # 2.0
rslora_scale = alpha / math.sqrt(r) # 8.0
assert rslora_scale != standard_scale
assert abs(rslora_scale - 8.0) < 1e-6
def test_sharded_efficient_merge(self, tmp_path):
"""End-to-end test of shard-by-shard merge."""
hidden = 32
r = 8
alpha = 16
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
# Create LoRA weights
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
hidden, r
),
"base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(
hidden, r
),
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
# Verify output exists and has merged weights
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
scale = alpha / r
q_key = "model.layers.0.self_attn.q_proj.weight"
expected_q = base_weights[q_key] + scale * (
lora_state[f"base_model.model.{q_key[:-7]}.lora_B.weight"]
@ lora_state[f"base_model.model.{q_key[:-7]}.lora_A.weight"]
)
assert torch.allclose(merged[q_key], expected_q, atol=1e-5)
# Embedding should be unchanged
assert torch.equal(
merged["model.embed_tokens.weight"],
base_weights["model.embed_tokens.weight"],
)
def test_dora_merge(self):
"""DoRA merge applies magnitude normalization via PEFT."""
hidden = 32
r = 8
alpha = 16
scale = alpha / r
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
magnitude = torch.randn(hidden).abs() + 0.1
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
"base_model.model.layer.q_proj.lora_magnitude_vector": magnitude,
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
merged, was_merged = _merge_tensor_with_lora(
base,
"layer.q_proj.weight",
lora_state,
scale,
config,
"cpu",
use_dora=True,
)
assert was_merged
# The merge should differ from both base and base+delta (DoRA applies normalization)
delta = scale * (lora_b @ lora_a)
assert not torch.allclose(merged, base, atol=1e-3)
assert not torch.allclose(merged, base + delta, atol=1e-3)
def test_fuse_unfuse_moe_merge(self):
"""Test fuse→merge→unfuse for MoE expert weights (WeightConverter path)."""
from axolotl.cli.utils.lora_merge import _fuse_and_unfuse_with_merge
hidden = 16
intermediate = 32
num_experts = 4
r = 4
alpha = 8
scale = alpha / r
# Simulate checkpoint format: per-expert separate tensors
shard_tensors = {}
for i in range(num_experts):
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = (
torch.randn(intermediate, hidden)
)
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = (
torch.randn(intermediate, hidden)
)
shard_tensors[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = (
torch.randn(hidden, intermediate)
)
shard_tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn(
hidden, hidden
)
# LoRA targets the fused key (runtime format)
lora_state = {
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight": torch.randn(
intermediate * 2, r
),
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_A.weight": torch.randn(
r, intermediate
),
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_B.weight": torch.randn(
hidden, r
),
}
# Build converters matching qwen2_moe pattern
from transformers.core_model_loading import (
Concatenate,
MergeModulelist,
WeightConverter,
)
converters = [
WeightConverter(
source_patterns=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
]
config = {"r": r, "lora_alpha": alpha}
result, merged_count, processed_keys = _fuse_and_unfuse_with_merge(
shard_tensors, converters, lora_state, scale, config, "cpu"
)
# Should have merged 2 LoRA targets (gate_up_proj and down_proj)
assert merged_count == 2
# Processed keys include original per-expert keys (removed) + fused keys (added)
assert len(processed_keys) > 0
# Output should be in fused format (runtime keys)
assert "model.layers.0.mlp.experts.gate_up_proj" in result
assert "model.layers.0.mlp.experts.down_proj" in result
# Per-expert keys should be removed
for i in range(num_experts):
assert f"model.layers.0.mlp.experts.{i}.gate_proj.weight" not in result
# Non-expert tensor should be passed through
assert "model.layers.0.self_attn.q_proj.weight" in result
# Verify fused tensors are 3D (stacked experts)
gate_up = result["model.layers.0.mlp.experts.gate_up_proj"]
assert gate_up.ndim == 3
assert gate_up.shape[0] == num_experts # [num_experts, intermediate*2, hidden]
# Verify the fused LoRA delta was applied correctly
# Reconstruct the fused base (stack per-expert, concat gate+up)
gate_stack = torch.stack(
[
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"]
for i in range(num_experts)
]
)
up_stack = torch.stack(
[
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"]
for i in range(num_experts)
]
)
base_fused = torch.cat([gate_stack, up_stack], dim=1)
lora_a = lora_state[
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight"
]
lora_b = lora_state[
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight"
]
expected_fused = base_fused + scale * (lora_b @ lora_a)
assert torch.allclose(gate_up, expected_fused, atol=1e-5)
def test_param_wrapper_merge_math(self):
"""ParamWrapper merge via PEFT's get_delta_weight matches manual einsum."""
num_experts = 4
r = 2
in_features = 8
out_features = 4
alpha = 4
base = torch.randn(num_experts, in_features, out_features)
lora_a = torch.randn(r * num_experts, in_features)
lora_b = torch.randn(out_features, r * num_experts)
config = {"r": r, "lora_alpha": alpha}
delta = _build_peft_layer_and_get_delta(
lora_a, lora_b, config, base, is_param_wrapper=True
)
assert delta.shape == base.shape
merged = base + delta
# Verify against manual einsum
scale = alpha / r
wa = lora_a.reshape(num_experts, r, in_features)
wb = lora_b.reshape(out_features, r, num_experts)
manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
for e in range(num_experts):
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
f"Expert {e} mismatch"
)
def test_param_wrapper_nesting_dim_filter(self):
"""_find_param_wrapper_lora skips wrong-dimension LoRA at outer level."""
num_experts = 4
r = 2
# Outer LoRA (gate_up_proj): A=[r*E, 8], B=[16, r*E]
# Inner LoRA (down_proj via base_layer): A=[r*E, 16], B=[8, r*E]
lora_state = {
"base_model.model.mod.experts.lora_A.weight": torch.randn(
r * num_experts, 8
),
"base_model.model.mod.experts.lora_B.weight": torch.randn(
16, r * num_experts
),
"base_model.model.mod.experts.base_layer.lora_A.weight": torch.randn(
r * num_experts, 16
),
"base_model.model.mod.experts.base_layer.lora_B.weight": torch.randn(
8, r * num_experts
),
}
# gate_up_proj shape [4, 8, 16] — should match outer LoRA
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.gate_up_proj", tensor_shape=(4, 8, 16)
)
assert a is not None and name == "gate_up_proj"
assert a.shape == (r * num_experts, 8) # outer
# down_proj shape [4, 16, 8] — outer dims don't match, should find inner
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.down_proj", tensor_shape=(4, 16, 8)
)
assert a is not None and name == "down_proj"
assert a.shape == (r * num_experts, 16) # inner (base_layer)
# shape that matches neither — should return None
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.other", tensor_shape=(4, 99, 99)
)
assert a is None
def test_find_lora_weights_with_renamings(self):
"""Weight renamings let checkpoint keys match LoRA keys."""
lora_state = {
"base_model.model.layers.0.mlp.fc1.lora_A.weight": torch.randn(8, 32),
"base_model.model.layers.0.mlp.fc1.lora_B.weight": torch.randn(32, 8),
}
# Direct lookup fails (checkpoint has "ff0", LoRA has "fc1")
a, b = find_lora_weights(lora_state, "layers.0.mlp.ff0.weight")
assert a is None
# With renaming ff0 → fc1, it should match
a, b = find_lora_weights(
lora_state, "layers.0.mlp.ff0.weight", weight_renamings={"ff0": "fc1"}
)
assert a is not None
assert a.shape == (8, 32)
def test_unmatched_tensors_pass_through(self):
"""Tensors with no matching LoRA are returned unchanged."""
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": torch.randn(8, 32),
"base_model.model.layer.q_proj.lora_B.weight": torch.randn(32, 8),
}
# 1D tensor (layernorm) — never matched
ln = torch.randn(32)
merged, was_merged = _merge_tensor_with_lora(
ln, "layer.norm.weight", lora_state, 2.0, {}, "cpu"
)
assert not was_merged
assert torch.equal(merged, ln)
# 2D tensor with no matching key
unrelated = torch.randn(64, 32)
merged, was_merged = _merge_tensor_with_lora(
unrelated, "layer.other_proj.weight", lora_state, 2.0, {}, "cpu"
)
assert not was_merged
assert torch.equal(merged, unrelated)
def test_fan_in_fan_out_transpose(self):
"""fan_in_fan_out config transposes the LoRA delta."""
hidden = 16
r = 4
alpha = 4 # scale = 1.0
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
lora_state = {
"base_model.model.layer.proj.lora_A.weight": lora_a,
"base_model.model.layer.proj.lora_B.weight": lora_b,
}
config_normal = {"r": r, "lora_alpha": alpha}
config_fif = {"r": r, "lora_alpha": alpha, "fan_in_fan_out": True}
merged_normal, _ = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, 1.0, config_normal, "cpu"
)
merged_fif, _ = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, 1.0, config_fif, "cpu"
)
delta = (alpha / r) * (lora_b @ lora_a)
assert torch.allclose(merged_normal, base + delta, atol=1e-5)
assert torch.allclose(merged_fif, base + delta.T, atol=1e-5)
assert not torch.allclose(merged_normal, merged_fif, atol=1e-5)
def test_rslora_end_to_end(self, tmp_path):
"""RSLoRA adapter uses alpha/sqrt(r) scaling in sharded merge."""
hidden = 16
r = 16
alpha = 32
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_rslora=True)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
rslora_scale = alpha / math.sqrt(r) # 8.0, not 2.0
q_key = "model.layers.0.self_attn.q_proj.weight"
expected = base_weights[q_key] + rslora_scale * (lora_b @ lora_a)
assert torch.allclose(merged[q_key], expected, atol=1e-5)
# Confirm it differs from standard scale
wrong_scale = alpha / r # 2.0
wrong_expected = base_weights[q_key] + wrong_scale * (lora_b @ lora_a)
assert not torch.allclose(merged[q_key], wrong_expected, atol=1e-3)
def test_multi_shard_index_json(self, tmp_path):
"""Multi-shard merge generates a correct weight-map index."""
hidden = 16
r = 4
alpha = 8
model_dir = tmp_path / "base_model"
model_dir.mkdir()
(model_dir / "config.json").write_text("{}")
# Create 2 shards
shard1 = {"model.layers.0.weight": torch.randn(hidden, hidden)}
shard2 = {"model.layers.1.weight": torch.randn(hidden, hidden)}
safetensors.torch.save_file(
shard1, model_dir / "model-00001-of-00002.safetensors"
)
safetensors.torch.save_file(
shard2, model_dir / "model-00002-of-00002.safetensors"
)
# Write a base model index (will be skipped by copy_non_model_files)
base_index = {
"metadata": {},
"weight_map": {
"model.layers.0.weight": "model-00001-of-00002.safetensors",
"model.layers.1.weight": "model-00002-of-00002.safetensors",
},
}
(model_dir / "model.safetensors.index.json").write_text(json.dumps(base_index))
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
safetensors.torch.save_file({}, adapter_dir / "adapter_model.safetensors")
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
# Verify index was generated
index_path = output_dir / "model.safetensors.index.json"
assert index_path.exists()
with open(index_path) as f:
idx = json.load(f)
assert "weight_map" in idx
assert len(idx["weight_map"]) == 2
# Each key should map to a shard that exists
for _key, shard_name in idx["weight_map"].items():
assert (output_dir / shard_name).exists(), f"Missing shard: {shard_name}"
def test_dora_end_to_end(self, tmp_path):
"""DoRA merge through the full sharded merge pipeline."""
hidden = 16
r = 4
alpha = 8
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_dora=True)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
magnitude = torch.randn(hidden).abs() + 0.1
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
"base_model.model.model.layers.0.self_attn.q_proj.lora_magnitude_vector": magnitude,
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
q_key = "model.layers.0.self_attn.q_proj.weight"
# Use PEFT's own get_delta_weight as the reference
delta = _build_peft_layer_and_get_delta(
lora_a,
lora_b,
{"r": r, "lora_alpha": alpha, "use_dora": True},
base_weights[q_key],
magnitude=magnitude,
)
expected = base_weights[q_key] + delta
assert torch.allclose(merged[q_key], expected, atol=1e-5)
# Verify it differs from standard (non-DoRA) merge
standard_delta = _build_peft_layer_and_get_delta(
lora_a,
lora_b,
{"r": r, "lora_alpha": alpha},
base_weights[q_key],
)
assert not torch.allclose(delta, standard_delta, atol=1e-3)
# v_proj has no LoRA weights — should be unchanged
v_key = "model.layers.0.self_attn.v_proj.weight"
assert torch.equal(merged[v_key], base_weights[v_key]), (
"v_proj should be unchanged (no LoRA weights for it)"
)
def test_dora_missing_magnitude_falls_back(self):
"""DoRA without magnitude vector falls back to standard LoRA merge."""
hidden = 16
r = 4
alpha = 8
scale = alpha / r
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
# No magnitude vector in lora_state
lora_state = {
"base_model.model.layer.proj.lora_A.weight": lora_a,
"base_model.model.layer.proj.lora_B.weight": lora_b,
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
merged, was_merged = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
)
assert was_merged
# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
# which produces a result different from plain W + scale * B @ A.
# Just verify it was merged (not unchanged).
assert not torch.equal(merged, base)