From b55706b9f6286003bd2385832396042e901d814e Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:11:32 +0530 Subject: [PATCH] feat:merge-lora iterate through bins without loading (#3095) * merge_method added * merge_efficient core implement * Update src/axolotl/cli/merge_lora.py Co-authored-by: Wing Lian * Update src/axolotl/utils/lora_merge_efficient.py Co-authored-by: Wing Lian * standard to leagcy + rstrip + try/except for do_merge_lora_efficient(cfg=cfg) * fix: 'dict' object has no attribute 'lora_alpha' * into -> debug * lint * lint2 * moved everythign to cpu + peformance improvments * lint * Update src/axolotl/cli/merge_lora.py Co-authored-by: Dan Saunders * Update src/axolotl/cli/merge_lora.py Co-authored-by: Dan Saunders * string handeling + try except remove * merge_method -> merge_lora_methods * remove duplicate cal + safetensor + move to lora_merge.py * lint * handle quant-dequant, handle experts * fix parameter merging and prefer peft's native merge logic per module --------- Co-authored-by: Wing Lian Co-authored-by: Dan Saunders --- src/axolotl/cli/merge_lora.py | 85 ++- src/axolotl/cli/utils/lora_merge.py | 982 +++++++++++++++++++++++++++ src/axolotl/monkeypatch/moe_quant.py | 47 +- src/axolotl/utils/schemas/peft.py | 6 + tests/utils/lora/test_merge_lora.py | 620 +++++++++++++++++ 5 files changed, 1735 insertions(+), 5 deletions(-) create mode 100644 src/axolotl/cli/utils/lora_merge.py diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index bc2dc84c7..dae9f317d 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -4,9 +4,11 @@ from pathlib import Path from typing import Union import fire +import torch from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -17,12 +19,26 @@ LOG = get_logger(__name__) @send_errors def do_merge_lora(*, cfg: DictDefault) -> None: """ - Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config - along with the LoRA adapters to combine them into a single base model. + Merges LoRA adapters with base model using either memory-efficient or legacy approach. Args: cfg: Dictionary mapping `axolotl` config keys to values. """ + merge_method = str(getattr(cfg, "merge_method", "memory_efficient")) + if merge_method == "legacy": + LOG.debug("Using legacy LoRA merging method...") + _do_merge_lora_legacy(cfg=cfg) + else: + LOG.debug("Using memory-efficient LoRA merging method...") + _do_merge_lora_efficient(cfg=cfg) + + +def _do_merge_lora_legacy(*, cfg: DictDefault) -> None: + """ + Legacy LoRA merging using merge_and_unload. + Loads the full model into memory before merging. + """ + LOG.debug("Using legacy LoRA merging method...") model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) LOG.info("Running merge of LoRA with base model...") @@ -52,6 +68,58 @@ def do_merge_lora(*, cfg: DictDefault) -> None: processor.save_pretrained(str(Path(cfg.output_dir) / "merged")) +def _do_merge_lora_efficient(*, cfg: DictDefault) -> None: + """ + Memory-efficient LoRA merging using shard-by-shard processing. + Does not load the full model into memory. + + Supports standard LoRA, RSLoRA, and DoRA. 
Unsupported methods (AdaLoRA, VeRA) + will raise NotImplementedError — use legacy method for those. + """ + LOG.debug("Using memory-efficient LoRA merging method...") + + output_path = Path(cfg.output_dir) / "merged" + safe_tensors = getattr(cfg, "save_safetensors", True) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Detect NF4 quantization from config to simulate QLoRA training dynamics. + # Check both current and original (pre-override) config values since do_cli + # forces load_in_4bit=False for the legacy path. + simulate_nf4 = bool( + getattr(cfg, "load_in_4bit", False) + or getattr(cfg, "_original_load_in_4bit", False) + or getattr(cfg, "adapter", None) == "qlora" + or getattr(cfg, "_original_adapter", None) == "qlora" + ) + + bnb_config_kwargs = getattr(cfg, "bnb_config_kwargs", None) or {} + nf4_blocksize = bnb_config_kwargs.get("blocksize", None) + nf4_double_quant = bnb_config_kwargs.get( + "bnb_4bit_use_double_quant", + getattr(cfg, "bnb_4bit_use_double_quant", True), + ) + + # Detect MoE expert quantization + simulate_nf4_experts = bool( + getattr(cfg, "quantize_moe_experts", False) + or getattr(cfg, "_original_quantize_moe_experts", False) + ) + + merge_lora_sharded_efficient( + base_model_path=cfg.base_model, + lora_adapter_path=cfg.lora_model_dir, + output_path=output_path, + safe_tensors=safe_tensors, + device=device, + simulate_nf4=simulate_nf4, + simulate_nf4_experts=simulate_nf4_experts, + nf4_blocksize=nf4_blocksize, + nf4_double_quant=nf4_double_quant, + ) + + LOG.debug("Memory-efficient LoRA merge completed successfully!") + + def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: """ Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various @@ -66,6 +134,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: ValueError: If target directory for LoRA merged model does not exist. """ + # Pre-load config to detect original quantization settings before overrides + raw_cfg = load_cfg(config, **kwargs) + original_load_in_4bit = getattr(raw_cfg, "load_in_4bit", False) + original_adapter = getattr(raw_cfg, "adapter", None) + original_quantize_moe_experts = getattr(raw_cfg, "quantize_moe_experts", False) + parsed_cfg = load_cfg( config, merge_lora=True, @@ -80,11 +154,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: **kwargs, ) + # Stash original quantization settings for NF4 simulation in efficient merge + parsed_cfg._original_load_in_4bit = original_load_in_4bit + parsed_cfg._original_adapter = original_adapter + parsed_cfg._original_quantize_moe_experts = original_quantize_moe_experts + if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir: parsed_cfg.lora_model_dir = parsed_cfg.output_dir if not Path(parsed_cfg.lora_model_dir).exists(): raise ValueError( - f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist." 
+ f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`" ) do_merge_lora(cfg=parsed_cfg) diff --git a/src/axolotl/cli/utils/lora_merge.py b/src/axolotl/cli/utils/lora_merge.py new file mode 100644 index 000000000..339e41e2d --- /dev/null +++ b/src/axolotl/cli/utils/lora_merge.py @@ -0,0 +1,982 @@ +import gc +import math +import os +import shutil +from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors +import safetensors.torch +import torch +from huggingface_hub import snapshot_download +from peft import LoraConfig +from tqdm import tqdm + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def _simulate_nf4_roundtrip( + tensor: torch.Tensor, + blocksize: Optional[int] = None, + compress_statistics: bool = True, + device: Optional[Union[str, torch.device]] = None, +) -> torch.Tensor: + """ + Simulate NF4 quantization roundtrip to match QLoRA training dynamics. + + During QLoRA training, base weights are quantized to NF4 and dequantized on-the-fly + for each forward pass. The LoRA adapters learn to compensate for the quantization + noise in the dequantized weights. To match this at merge time, we apply the same + quantize → dequantize roundtrip so the merged result reflects what the model saw + during training. + + Args: + tensor: Base model weight tensor (fp16/bf16/fp32) + blocksize: NF4 quantization block size (default: bitsandbytes default) + compress_statistics: Whether to use double quantization + device: Device for quantization computation. bitsandbytes requires a + CUDA device; defaults to "cuda" when available. + + Returns: + Tensor after NF4 quantize → dequantize roundtrip, in original dtype + """ + import bitsandbytes.functional as bnb_F + + quant_device: torch.device + if device is None: + quant_device = torch.device("cuda") + elif isinstance(device, str): + quant_device = torch.device(device) + else: + quant_device = device + + if quant_device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError( + "NF4 simulation requires CUDA but no GPU is available. " + "Either run on a machine with a GPU or disable NF4 simulation." + ) + + original_dtype = tensor.dtype + original_shape = tensor.shape + + # bitsandbytes requires float32 input for quantization and contiguous+CUDA tensor + flat = tensor.reshape(-1).to(torch.float32).contiguous().to(quant_device) + + quant_kwargs = { + "quant_type": "nf4", + "compress_statistics": compress_statistics, + } + if blocksize is not None: + quant_kwargs["blocksize"] = blocksize + + quantized, quant_state = bnb_F.quantize_4bit(flat, **quant_kwargs) + dequantized = bnb_F.dequantize_4bit(quantized, quant_state, quant_type="nf4") + + return dequantized.reshape(original_shape).to(original_dtype).cpu() + + +def find_lora_weights( + lora_state: Dict[str, torch.Tensor], + key: str, + weight_renamings: Optional[Dict[str, str]] = None, +) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Find corresponding LoRA A and B weights for a given key. + + Also tries keys after applying weight renamings (from transformers v5 + conversion mappings) in case the checkpoint key names differ from the + runtime model key names used by the LoRA adapter. 
+ """ + import re + + clean_key = key[:-7] if key.endswith(".weight") else key + + # Try the direct key first + a_key = f"base_model.model.{clean_key}.lora_A.weight" + b_key = f"base_model.model.{clean_key}.lora_B.weight" + + lora_a = lora_state.get(a_key) + lora_b = lora_state.get(b_key) + + if lora_a is not None and lora_b is not None: + return lora_a, lora_b + + # Try renamed keys (checkpoint format → runtime format) + if weight_renamings: + for src_pattern, tgt_pattern in weight_renamings.items(): + renamed_key = re.sub(src_pattern, tgt_pattern, clean_key) + if renamed_key != clean_key: + a_key = f"base_model.model.{renamed_key}.lora_A.weight" + b_key = f"base_model.model.{renamed_key}.lora_B.weight" + lora_a = lora_state.get(a_key) + lora_b = lora_state.get(b_key) + if lora_a is not None and lora_b is not None: + return lora_a, lora_b + + return None, None + + +def _find_param_wrapper_lora( + lora_state: Dict[str, torch.Tensor], + key: str, + tensor_shape: Optional[tuple] = None, +) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[str]]: + """ + Find LoRA weights from a ParamWrapper (lora_target_parameters) that targets + a parent module containing this weight as a sub-parameter. + + For example, base weight key 'model.layers.0.mlp.experts.down_proj' may have + LoRA at 'base_model.model.model.layers.0.mlp.experts.lora_A.weight' (targeting + the 'experts' module with 'down_proj' as the parameter_name). + + When tensor_shape is provided, validates that the LoRA dimensions match the + target tensor (important when multiple ParamWrappers are nested and each + nesting level has different LoRA dimensions). + + Returns (lora_A, lora_B, parameter_name) or (None, None, None). + """ + clean_key = key[:-7] if key.endswith(".weight") else key + # Strip trailing parameter name to get the parent module path + # e.g., "model.layers.0.mlp.experts.down_proj" → parent="model.layers.0.mlp.experts", param="down_proj" + parts = clean_key.rsplit(".", 1) + if len(parts) != 2: + return None, None, None + + parent_key, param_name = parts + + # PEFT's ParamWrapper nesting: when multiple parameters are targeted on + # the same module, it nests wrappers. The outer wrapper's LoRA is at + # parent.lora_A/B and inner wrappers use parent.base_layer.lora_A/B, + # parent.base_layer.base_layer.lora_A/B, etc. + prefixes_to_try = [ + f"base_model.model.{parent_key}", + ] + # Walk up .base_layer nesting levels (typically 1-2 deep) + for depth in range(1, 4): + bl = ".base_layer" * depth + prefixes_to_try.append(f"base_model.model.{parent_key}{bl}") + + for prefix in prefixes_to_try: + a_key = f"{prefix}.lora_A.weight" + b_key = f"{prefix}.lora_B.weight" + lora_a = lora_state.get(a_key) + lora_b = lora_state.get(b_key) + if lora_a is None or lora_b is None: + continue + + # When tensor_shape is given, verify dimensions match before returning. + # This prevents returning a mismatched LoRA from a different nesting level. 
+ if tensor_shape is not None and len(tensor_shape) >= 3: + num_experts = tensor_shape[0] + if not ( + lora_a.shape[0] == lora_b.shape[1] + and lora_a.shape[0] % num_experts == 0 + and lora_a.shape[1] == tensor_shape[1] + and lora_b.shape[0] == tensor_shape[2] + ): + continue # Dimensions don't match, try next nesting level + + return lora_a, lora_b, param_name + + return None, None, None + + +def _build_peft_layer_and_get_delta( + lora_a: torch.Tensor, + lora_b: torch.Tensor, + lora_config_dict: Dict, + base_tensor: torch.Tensor, + adapter_name: str = "default", + is_param_wrapper: bool = False, + magnitude: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Use PEFT's own layer classes to compute the LoRA delta weight. + + Instead of re-implementing the merge math for every LoRA variant, this + constructs a lightweight PEFT layer, loads the A/B weights, and calls + ``get_delta_weight`` (or ``merge`` for DoRA) which handles standard LoRA, + RSLoRA, DoRA, and ParamWrapper (expert-blocked) LoRA. + + Returns the delta tensor (same shape as base_tensor). + """ + import warnings + + import torch.nn as nn + + r_total = lora_a.shape[0] + in_features = lora_a.shape[1] + out_features = lora_b.shape[0] + lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1)) + use_rslora = bool(lora_config_dict.get("use_rslora", False)) + use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None + + if is_param_wrapper: + from peft.tuners.lora.layer import ParamWrapper + + num_experts = base_tensor.shape[0] + r = r_total // num_experts + + class _FakeModule(nn.Module): + pass + + fake = _FakeModule() + fake.register_parameter( + "weight", nn.Parameter(base_tensor.clone(), requires_grad=False) + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + layer = ParamWrapper( + fake, + adapter_name=adapter_name, + parameter_name="weight", + r=r, + lora_alpha=lora_alpha, + use_rslora=use_rslora, + ) + layer.lora_A[adapter_name].weight.data = lora_a + layer.lora_B[adapter_name].weight.data = lora_b + return layer.get_delta_weight(adapter_name) + else: + from peft.tuners.lora.layer import Linear as LoraLinear + + base_layer = nn.Linear(in_features, out_features, bias=False) + base_layer.weight.data = base_tensor.clone() + + fan_in_fan_out = bool( + lora_config_dict.get("fan_in_fan_out", False) + or lora_config_dict.get("lora_fan_in_fan_out", False) + ) + + layer = LoraLinear( + base_layer, + adapter_name=adapter_name, + r=r_total, + lora_alpha=lora_alpha, + fan_in_fan_out=fan_in_fan_out, + use_rslora=use_rslora, + use_dora=use_dora, + ) + layer.lora_A[adapter_name].weight.data = lora_a + layer.lora_B[adapter_name].weight.data = lora_b + + if use_dora: + # DoRA merges magnitude normalization into the weight directly. + # Use PEFT's merge() which handles DoRA internally, then + # compute the delta as merged_weight - original_weight. 
+ mag_layer = layer.lora_magnitude_vector[adapter_name] + mag_layer.weight = nn.Parameter(magnitude) + layer.merge(adapter_names=[adapter_name]) + return base_layer.weight.data - base_tensor + + return layer.get_delta_weight(adapter_name) + + +def get_model_shards(model_path: Path) -> list[Path]: + """Find all model shards in the given path.""" + shards: list[Path] = [] + + patterns = ["model*.safetensors", "pytorch_model*.bin"] + + for pattern in patterns: + shards.extend(model_path.glob(pattern)) + if shards: + break + + return sorted(shards) + + +def copy_non_model_files( + input_path: Path, output_path: Path, model_shards: list[Path] +) -> None: + """ + Copy all non-model files to the output directory. + + Args: + input_path: Source directory + output_path: Destination directory + model_shards: List of model shard files to skip + """ + LOG.info("Copying non-model files to output directory...") + + shard_names = {shard.name for shard in model_shards} + + for filepath in input_path.glob("*"): + if filepath.is_dir(): + continue + if filepath.name in shard_names: + continue + if ( + filepath.name.startswith("model") and filepath.suffix == ".safetensors" + ) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"): + continue + if filepath.suffix == ".gguf": + continue + # Skip weight-map index files — they reference shard filenames that may + # change during the merge (e.g. .bin → .safetensors). A correct index + # is regenerated after all shards have been written. + if filepath.name.endswith(".index.json"): + continue + + LOG.debug(f"Copying {filepath.name} to output") + shutil.copy2(filepath, output_path) + + +def _find_dora_magnitude( + lora_state: Dict[str, torch.Tensor], + key: str, + weight_renamings: Optional[Dict[str, str]] = None, +) -> Optional[torch.Tensor]: + """ + Find DoRA magnitude vector for a given key. + """ + import re + + clean_key = key[:-7] if key.endswith(".weight") else key + mag_key = f"base_model.model.{clean_key}.lora_magnitude_vector" + result = lora_state.get(mag_key) + if result is not None: + return result + + if weight_renamings: + for src_pattern, tgt_pattern in weight_renamings.items(): + renamed_key = re.sub(src_pattern, tgt_pattern, clean_key) + if renamed_key != clean_key: + mag_key = f"base_model.model.{renamed_key}.lora_magnitude_vector" + result = lora_state.get(mag_key) + if result is not None: + return result + + return None + + +def _should_nf4_roundtrip( + key: str, + tensor: torch.Tensor, + simulate_nf4: bool, + simulate_nf4_experts: bool, +) -> bool: + """Determine if a tensor should undergo NF4 quantization roundtrip.""" + if tensor.ndim < 2: + return False + if simulate_nf4: + return True + if simulate_nf4_experts and tensor.ndim >= 3 and "expert" in key.lower(): + return True + return False + + +def _merge_tensor_with_lora( + tensor: torch.Tensor, + key: str, + lora_state: Dict[str, torch.Tensor], + scale: float, + lora_config_dict: Dict, + device: str, + simulate_nf4: bool = False, + simulate_nf4_experts: bool = False, + nf4_blocksize: Optional[int] = None, + nf4_double_quant: bool = True, + use_dora: bool = False, + weight_renamings: Optional[Dict[str, str]] = None, +) -> tuple[torch.Tensor, bool]: + """ + Helper function to merge a single tensor with its corresponding LoRA weights. 
+ + Args: + tensor: Base model tensor + key: Tensor key/name + lora_state: Dictionary containing LoRA weights + scale: LoRA scaling factor (alpha/r) + lora_config_dict: LoRA configuration dictionary + device: Device to perform computations on + simulate_nf4: Whether to simulate NF4 quantization roundtrip for all weights + simulate_nf4_experts: Whether to simulate NF4 roundtrip for MoE expert tensors only + nf4_blocksize: Block size for NF4 quantization + nf4_double_quant: Whether to use double quantization + use_dora: Whether to apply DoRA (Weight-Decomposed LoRA) merging + weight_renamings: Optional key renamings from transformers conversion mapping + + Returns: + Tuple of (merged tensor, whether LoRA was applied) + """ + lora_a, lora_b = find_lora_weights(lora_state, key, weight_renamings) + + do_nf4 = _should_nf4_roundtrip(key, tensor, simulate_nf4, simulate_nf4_experts) + + if lora_a is not None and lora_b is not None: + LOG.debug(f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}") + + original_dtype = tensor.dtype + + # Simulate NF4 quantization roundtrip to match QLoRA training dynamics + if do_nf4: + tensor = _simulate_nf4_roundtrip( + tensor, + blocksize=nf4_blocksize, + compress_statistics=nf4_double_quant, + device=device, + ) + + magnitude = ( + _find_dora_magnitude(lora_state, key, weight_renamings) + if use_dora + else None + ) + delta = _build_peft_layer_and_get_delta( + lora_a.to(device), + lora_b.to(device), + lora_config_dict, + tensor.to(device), + magnitude=magnitude.to(device) if magnitude is not None else None, + ) + merged_tensor = ( + (tensor.to(device).to(torch.float32) + delta.to(torch.float32)) + .to(original_dtype) + .detach() + .cpu() + ) + return merged_tensor, True + else: + # Try ParamWrapper LoRA (lora_target_parameters) — the LoRA targets a + # parent module and this weight is a sub-parameter of that module. + if tensor.ndim >= 3: + pw_a, pw_b, param_name = _find_param_wrapper_lora( + lora_state, key, tensor_shape=tuple(tensor.shape) + ) + if pw_a is not None and pw_b is not None: + LOG.debug( + f"Merging ParamWrapper LoRA for {key} " + f"(param={param_name}): {pw_a.shape}, {pw_b.shape}" + ) + if do_nf4: + tensor = _simulate_nf4_roundtrip( + tensor, + blocksize=nf4_blocksize, + compress_statistics=nf4_double_quant, + device=device, + ) + original_dtype = tensor.dtype + delta = _build_peft_layer_and_get_delta( + pw_a.to(device), + pw_b.to(device), + lora_config_dict, + tensor.to(device), + is_param_wrapper=True, + ) + merged = ( + (tensor.to(device).to(torch.float32) + delta.to(torch.float32)) + .to(original_dtype) + .detach() + .cpu() + ) + return merged, True + + if do_nf4: + tensor = _simulate_nf4_roundtrip( + tensor, + blocksize=nf4_blocksize, + compress_statistics=nf4_double_quant, + device=device, + ) + return tensor.detach().cpu(), False + + +def _get_conversion_info(base_model_path: Path) -> tuple[Dict[str, str], list]: + """ + Load the model's config.json and check if transformers has WeightRenaming + or WeightConverter mappings for this model type. 
+ + Returns: + - dict of {source_pattern: target_pattern} for simple renamings + - list of WeightConverter objects for fuse/unfuse operations + """ + import json as _json + + config_path = base_model_path / "config.json" + if not config_path.exists(): + return {}, [] + + try: + with open(config_path) as f: + model_config = _json.load(f) + except (OSError, _json.JSONDecodeError): + return {}, [] + + model_type = model_config.get("model_type") + if not model_type: + return {}, [] + + try: + from transformers.conversion_mapping import get_checkpoint_conversion_mapping + from transformers.core_model_loading import WeightConverter, WeightRenaming + except ImportError: + return {}, [] + + conversions = get_checkpoint_conversion_mapping(model_type) + if not conversions: + return {}, [] + + renamings = {} + weight_converters = [] + for conv in conversions: + if isinstance(conv, WeightRenaming): + # WeightRenaming stores patterns as lists internally + src_list = ( + conv.source_patterns + if isinstance(conv.source_patterns, list) + else [conv.source_patterns] + ) + tgt_list = ( + conv.target_patterns + if isinstance(conv.target_patterns, list) + else [conv.target_patterns] + ) + if len(src_list) == 1 and len(tgt_list) == 1: + renamings[src_list[0]] = tgt_list[0] + elif isinstance(conv, WeightConverter): + weight_converters.append(conv) + + return renamings, weight_converters + + +def _fuse_and_unfuse_with_merge( + shard_tensors: Dict[str, torch.Tensor], + weight_converters: list, + lora_state: Dict[str, torch.Tensor], + scale: float, + lora_config_dict: Dict, + device: str, + simulate_nf4: bool = False, + simulate_nf4_experts: bool = False, + nf4_blocksize: Optional[int] = None, + nf4_double_quant: bool = True, + use_dora: bool = False, + weight_renamings: Optional[Dict[str, str]] = None, +) -> tuple[Dict[str, torch.Tensor], int, set]: + """ + For tensors matching WeightConverter patterns (MoE expert weights): + 1. Fuse checkpoint-format tensors into runtime-format (e.g., per-expert → fused 3D) + 2. Apply NF4 roundtrip + LoRA merge on the fused tensor + 3. Unfuse back to checkpoint format for saving + + Returns: + - Updated tensor dict + - Count of merged LoRA targets + - Set of keys that were processed (fused/merged/unfused) and should be + skipped by the per-tensor merge pass to avoid double NF4 roundtrip + """ + import re + + from transformers.core_model_loading import Concatenate, MergeModulelist + + result = dict(shard_tensors) # Start with all tensors + merged_count = 0 + processed_keys: set = set() # Keys that were fuse/unfuse processed + + for converter in weight_converters: + src_patterns = ( + converter.source_patterns + if isinstance(converter.source_patterns, list) + else [converter.source_patterns] + ) + tgt_patterns = ( + converter.target_patterns + if isinstance(converter.target_patterns, list) + else [converter.target_patterns] + ) + + # Build regex for each source pattern + pattern_regexes = [] + for pat in src_patterns: + regex_str = re.escape(pat).replace(r"\.\*\.", r"\.(\d+)\.") + regex_str = ( + regex_str.rstrip(r"\$") if regex_str.endswith(r"\$") else regex_str + ) + pattern_regexes.append(re.compile(r"(.*\.)?" 
+ regex_str + "$")) + + # Group matching keys by layer prefix and source pattern + # {layer_prefix: {pat_idx: {expert_idx: (key, tensor)}}} + layer_groups: Dict[str, Dict[int, Dict[int, tuple[str, torch.Tensor]]]] = {} + + for key in list(result.keys()): + for pat_idx, pat_regex in enumerate(pattern_regexes): + match = pat_regex.match(key) + if match: + prefix = match.group(1) or "" + # Extract expert index from the matched portion + remaining = key[len(prefix) :] + expert_match = re.search(r"\.(\d+)\.", remaining) + expert_idx = int(expert_match.group(1)) if expert_match else 0 + + layer_groups.setdefault(prefix, {}).setdefault(pat_idx, {})[ + expert_idx + ] = (key, result[key]) + break + + # Process each layer group + for prefix, pat_groups in layer_groups.items(): + # Check we have all source patterns for this layer + if not pat_groups: + continue + + # Step 1: Fuse — MergeModulelist (stack experts) per source pattern + fused_per_pattern = {} + original_keys_per_pattern: Dict[int, list[str]] = {} + num_experts = None + + for pat_idx in sorted(pat_groups.keys()): + expert_data = pat_groups[pat_idx] + sorted_indices = sorted(expert_data.keys()) + if num_experts is None: + num_experts = len(sorted_indices) + + sorted_tensors = [expert_data[idx][1] for idx in sorted_indices] + original_keys_per_pattern[pat_idx] = [ + expert_data[idx][0] for idx in sorted_indices + ] + fused_per_pattern[src_patterns[pat_idx]] = torch.stack( + sorted_tensors, dim=0 + ) + + # Apply remaining operations (Concatenate) + fused_tensor = None + has_concat = False + concat_dim = 1 # default + + for op in converter.operations: + if isinstance(op, MergeModulelist): + pass # Already handled + elif isinstance(op, Concatenate): + has_concat = True + concat_dim = op.dim + tensors_to_cat = [ + fused_per_pattern[sp] + for sp in src_patterns + if sp in fused_per_pattern + ] + if len(tensors_to_cat) > 1: + fused_tensor = torch.cat(tensors_to_cat, dim=concat_dim) + elif tensors_to_cat: + fused_tensor = tensors_to_cat[0] + + if not has_concat and len(fused_per_pattern) == 1: + fused_tensor = next(iter(fused_per_pattern.values())) + + if fused_tensor is None: + continue + + # Step 2: Build the fused key name and merge LoRA + fused_key = prefix + tgt_patterns[0] + + # Apply NF4 roundtrip on the fused tensor (matching training dynamics) + do_nf4 = _should_nf4_roundtrip( + fused_key, fused_tensor, simulate_nf4, simulate_nf4_experts + ) + if do_nf4: + fused_tensor = _simulate_nf4_roundtrip( + fused_tensor, + blocksize=nf4_blocksize, + compress_statistics=nf4_double_quant, + device=device, + ) + + # Try to find and merge LoRA weights for the fused key + lora_a, lora_b = find_lora_weights(lora_state, fused_key, weight_renamings) + if lora_a is not None and lora_b is not None: + LOG.debug( + f"Merging LoRA for fused key {fused_key}: {lora_a.shape}, {lora_b.shape}" + ) + original_dtype = fused_tensor.dtype + magnitude = ( + _find_dora_magnitude(lora_state, fused_key, weight_renamings) + if use_dora + else None + ) + delta = _build_peft_layer_and_get_delta( + lora_a.to(device), + lora_b.to(device), + lora_config_dict, + fused_tensor.to(device), + magnitude=magnitude.to(device) if magnitude is not None else None, + ) + fused_tensor = ( + ( + fused_tensor.to(device).to(torch.float32) + + delta.to(torch.float32) + ) + .to(original_dtype) + .detach() + .cpu() + ) + merged_count += 1 + + # Step 3: Save in fused format (runtime format) so that the merged + # model can be loaded directly without needing WeightConverter + # fusion during 
from_pretrained (which can OOM for large MoE models). + # Remove the original per-expert keys and save the fused tensor + # under the runtime key name. + for pat_idx in sorted(original_keys_per_pattern.keys()): + for ok in original_keys_per_pattern[pat_idx]: + result.pop(ok, None) + processed_keys.add(ok) + + result[fused_key] = fused_tensor.detach().cpu() + processed_keys.add(fused_key) + + return result, merged_count, processed_keys + + +def merge_lora_sharded_efficient( + base_model_path: Union[str, Path], + lora_adapter_path: Union[str, Path], + output_path: Union[str, Path], + device: str = "cpu", + safe_tensors: bool = True, + simulate_nf4: bool = False, + simulate_nf4_experts: bool = False, + nf4_blocksize: Optional[int] = None, + nf4_double_quant: bool = True, +) -> None: + """ + Memory-efficient LoRA merging that processes shards individually + without loading the full model into memory. + + Args: + simulate_nf4: Apply NF4 roundtrip to ALL weight tensors (for QLoRA) + simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors + (for quantize_moe_experts). Expert tensors are identified by having + "expert" in the key name and ndim >= 3. + """ + base_model_path = Path(base_model_path) + lora_adapter_path = Path(lora_adapter_path) + output_path = Path(output_path) + + if "/" in str(base_model_path) and not base_model_path.exists(): + base_model_path = Path(snapshot_download(str(base_model_path))) + + # Check for weight conversion requirements (transformers v5) + weight_renamings, weight_converters = _get_conversion_info(base_model_path) + if weight_renamings: + LOG.debug(f"Found {len(weight_renamings)} weight renamings for this model type") + if weight_converters: + LOG.debug( + f"Found {len(weight_converters)} weight converters (fuse/unfuse) for this model type. " + f"Will fuse→merge→unfuse within each shard." 
+ ) + + os.makedirs(output_path, exist_ok=True) + + config_file = lora_adapter_path / "adapter_config.json" + if not config_file.exists(): + raise FileNotFoundError(f"LoRA config not found: {config_file}") + + lora_config_dict = LoraConfig.from_json_file(str(config_file)) + if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0: + raise ValueError("LoRA config 'r' must be > 0") + + use_dora = bool(lora_config_dict.get("use_dora", False)) + + unsupported_methods = [] + + # Check for AdaLoRA (Adaptive LoRA) + if lora_config_dict.get("use_adalora", False): + unsupported_methods.append("AdaLoRA (Adaptive LoRA)") + + # Check for VeRA (Vector-based Random Matrix Adaptation) + if lora_config_dict.get("use_vera", False): + unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)") + + # Check for other advanced LoRA variants by task_type + task_type = lora_config_dict.get("task_type", "") + if task_type and task_type not in [ + "CAUSAL_LM", + "SEQ_2_SEQ_LM", + "TOKEN_CLS", + "SEQ_CLS", + "QUESTION_ANS", + ]: + unsupported_methods.append(f"Task type: {task_type}") + + # Check for rank adaptation patterns (AdaLoRA indicators) + # Use .get() so empty dicts/None don't false-positive + if any( + lora_config_dict.get(key) + for key in ["rank_pattern", "alpha_pattern", "target_rank"] + ): + unsupported_methods.append("AdaLoRA (rank adaptation detected)") + + # Check for advanced initialization methods + init_lora_weights = lora_config_dict.get("init_lora_weights", "") + if init_lora_weights and init_lora_weights not in [ + "gaussian", + "loftq", + True, + False, + ]: + unsupported_methods.append(f"Advanced initialization: {init_lora_weights}") + + if unsupported_methods: + methods_str = ", ".join(unsupported_methods) + raise NotImplementedError( + f"Memory-efficient LoRA merge only supports standard LoRA. " + f"Detected unsupported methods: {methods_str}. " + f"Please use the legacy merge method for advanced LoRA variants." 
+ ) + + use_rslora = bool(lora_config_dict.get("use_rslora", False)) + if use_rslora: + scale = float(lora_config_dict["lora_alpha"]) / math.sqrt( + float(lora_config_dict["r"]) + ) + else: + scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"]) + + LOG.debug(f"LoRA scale factor: {scale} (rslora={use_rslora})") + + if simulate_nf4: + LOG.info( + "NF4 simulation enabled: base weights will undergo quantize→dequantize " + "roundtrip before LoRA merge to match QLoRA training dynamics" + ) + + lora_file = lora_adapter_path / "adapter_model.safetensors" + if not lora_file.exists(): + lora_file = lora_adapter_path / "adapter_model.bin" + if not lora_file.exists(): + raise FileNotFoundError( + f"LoRA adapter weights not found in {lora_adapter_path}" + ) + + LOG.debug(f"Loading LoRA weights from {lora_file}") + + if lora_file.suffix == ".safetensors": + lora_state = safetensors.torch.load_file(lora_file) + else: + lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614 + LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge") + + model_shards = get_model_shards(base_model_path) + if not model_shards: + raise FileNotFoundError(f"No model shards found in {base_model_path}") + + LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}") + copy_non_model_files(base_model_path, output_path, model_shards) + + merged_count = 0 + total_tensors = 0 + # Track weight_map for index regeneration: {tensor_key: shard_filename} + weight_map: Dict[str, str] = {} + + for shard_path in tqdm(model_shards, desc="Merging shards"): + merged_tensors = {} + metadata = {} + + # Load all tensors from the shard + if shard_path.suffix == ".safetensors": + with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f: + if hasattr(f, "metadata") and f.metadata(): + metadata = f.metadata() + shard_tensors = {key: f.get_tensor(key) for key in f.keys()} + else: + shard_tensors = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + ) + + total_tensors += len(shard_tensors) + + # Step 1: Handle fused weight conversions (MoE experts) if applicable + fused_keys: set = set() + if weight_converters: + shard_tensors, fused_merged, fused_keys = _fuse_and_unfuse_with_merge( + shard_tensors, + weight_converters, + lora_state, + scale, + lora_config_dict, + device, + simulate_nf4=simulate_nf4, + simulate_nf4_experts=simulate_nf4_experts, + nf4_blocksize=nf4_blocksize, + nf4_double_quant=nf4_double_quant, + use_dora=use_dora, + weight_renamings=weight_renamings, + ) + merged_count += fused_merged + + # Step 2: Merge remaining (non-fused) tensors with LoRA + # Skip keys already processed by fuse/unfuse to avoid double NF4 roundtrip + for key, tensor in shard_tensors.items(): + if key in fused_keys: + merged_tensors[key] = tensor.detach().cpu() + continue + merged_tensor, was_merged = _merge_tensor_with_lora( + tensor, + key, + lora_state, + scale, + lora_config_dict, + device, + simulate_nf4=simulate_nf4, + simulate_nf4_experts=simulate_nf4_experts, + nf4_blocksize=nf4_blocksize, + nf4_double_quant=nf4_double_quant, + use_dora=use_dora, + weight_renamings=weight_renamings, + ) + merged_tensors[key] = merged_tensor + if was_merged: + merged_count += 1 + + output_shard_path = output_path / shard_path.name + merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()} + + if safe_tensors: + if not str(output_shard_path).endswith(".safetensors"): + output_shard_path = output_path 
/ (shard_path.stem + ".safetensors") + safetensors.torch.save_file( + merged_tensors, output_shard_path, metadata=metadata + ) + else: + if shard_path.suffix == ".safetensors": + safetensors.torch.save_file( + merged_tensors, output_shard_path, metadata=metadata + ) + else: + torch.save(merged_tensors, output_shard_path) + + for tensor_key in merged_tensors: + weight_map[tensor_key] = output_shard_path.name + + del merged_tensors, shard_tensors + if device != "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # Regenerate weight-map index if the model was sharded + if len(model_shards) > 1 and weight_map: + import json as _json + + index_name = ( + "model.safetensors.index.json" + if safe_tensors + else "pytorch_model.bin.index.json" + ) + index = { + "metadata": {"total_size": total_tensors}, + "weight_map": weight_map, + } + with open(output_path / index_name, "w") as f: + _json.dump(index, f, indent=2) + LOG.debug(f"Wrote weight-map index: {index_name}") + + if merged_count == 0: + LOG.warning( + "No LoRA weights were matched to base model tensors. " + "This may indicate a key name mismatch between the checkpoint format " + "and the LoRA adapter. Consider using merge_method: legacy." + ) + LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors") diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index 983da4a37..68c458f5a 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -164,6 +164,16 @@ def patch_peft_target_parameters_matching(): from peft.utils.integrations import init_empty_weights from peft.utils.other import _get_submodules + # Mapping from unfused parameter names to their fused equivalents. + # When a model stores fused weights (e.g. gate_up_proj) but the user + # specifies unfused names (gate_proj, up_proj), we auto-expand so the + # fused parameter is also targeted. The original unfused names are kept + # in the set so that models that do NOT fuse still work. + _UNFUSED_TO_FUSED: dict[str, str] = { + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + def _patched_inject_parameters( self, peft_config, model, adapter_name, low_cpu_mem_usage ): @@ -176,10 +186,43 @@ def patch_peft_target_parameters_matching(): continue for target in original_targets: mod_path, _, param_name = target.rpartition(".") - if ( + if not ( module_name == mod_path or module_name.endswith("." + mod_path) - ) and hasattr(module, param_name): + ): + continue + + if hasattr(module, param_name): expanded.add(f"{module_name}.{param_name}") + elif param_name in _UNFUSED_TO_FUSED: + # The model uses fused weights (e.g. gate_up_proj) but the + # user specified unfused names (gate_proj / up_proj). + fused_name = _UNFUSED_TO_FUSED[param_name] + if hasattr(module, fused_name): + if fused_name not in expanded: + LOG.warning( + "target_parameter '%s' not found on %s, " + "but fused equivalent '%s' exists — adding " + "it automatically.", + param_name, + module_name, + fused_name, + ) + expanded.add(f"{module_name}.{fused_name}") + else: + LOG.warning( + "target_parameter '%s' not found on %s and no " + "fused equivalent exists either — skipping.", + param_name, + module_name, + ) + else: + LOG.warning( + "target_parameter '%s' not found on %s — skipping. 
" + "Check that the parameter name matches the model's " + "weight names.", + param_name, + module_name, + ) target_names_set = expanded diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index 5b90fb63f..c60c548f0 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -155,6 +155,12 @@ class LoraConfig(BaseModel): ) merge_lora: bool | None = None + merge_method: Literal["legacy", "memory_efficient"] | None = Field( + default="memory_efficient", + json_schema_extra={ + "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory." + }, + ) @model_validator(mode="before") @classmethod diff --git a/tests/utils/lora/test_merge_lora.py b/tests/utils/lora/test_merge_lora.py index 8edccafb9..e5d7f535d 100644 --- a/tests/utils/lora/test_merge_lora.py +++ b/tests/utils/lora/test_merge_lora.py @@ -1,8 +1,18 @@ +import json +import math from unittest.mock import Mock, patch +import safetensors.torch import torch from axolotl.cli.merge_lora import do_merge_lora +from axolotl.cli.utils.lora_merge import ( + _build_peft_layer_and_get_delta, + _find_param_wrapper_lora, + _merge_tensor_with_lora, + find_lora_weights, + merge_lora_sharded_efficient, +) from axolotl.utils.dict import DictDefault @@ -132,6 +142,7 @@ class TestAdapterMergeUnmerge: "torch_dtype": torch.float32, "local_rank": 0, "output_dir": str(tmp_path), + "merge_method": "legacy", } ) @@ -167,6 +178,7 @@ class TestAdapterMergeUnmerge: "save_safetensors": True, "output_dir": str(tmp_path), "local_rank": 0, + "merge_method": "legacy", } ) @@ -179,3 +191,611 @@ class TestAdapterMergeUnmerge: do_merge_lora(cfg=cfg) assert mock_load.called + + +class TestEfficientMerge: + """Test suite for memory-efficient shard-by-shard LoRA merge.""" + + def _make_adapter(self, tmp_path, r=8, alpha=16, use_dora=False, use_rslora=False): + """Create a minimal adapter directory with config + weights.""" + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + + config = { + "r": r, + "lora_alpha": alpha, + "target_modules": ["q_proj", "v_proj"], + "task_type": "CAUSAL_LM", + "bias": "none", + "use_dora": use_dora, + "use_rslora": use_rslora, + } + (adapter_dir / "adapter_config.json").write_text(json.dumps(config)) + return adapter_dir, config + + def _make_base_model(self, tmp_path, hidden=32): + """Create a minimal base model directory with one shard.""" + model_dir = tmp_path / "base_model" + model_dir.mkdir() + + weights = { + "model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden), + "model.layers.0.self_attn.v_proj.weight": torch.randn(hidden, hidden), + "model.embed_tokens.weight": torch.randn(100, hidden), + } + safetensors.torch.save_file(weights, model_dir / "model.safetensors") + + # Minimal config files + (model_dir / "config.json").write_text("{}") + return model_dir, weights + + def test_find_lora_weights(self): + lora_state = { + "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn( + 8, 32 + ), + "base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn( + 32, 8 + ), + } + a, b = find_lora_weights(lora_state, "layers.0.self_attn.q_proj.weight") + assert a is not None and b is not None + assert a.shape == (8, 32) + + a, b = find_lora_weights(lora_state, "layers.0.self_attn.v_proj.weight") + assert a is None and b is None + + def test_merge_tensor_basic(self): + hidden = 32 + r = 8 + alpha = 16 + base = 
torch.randn(hidden, hidden) + lora_a = torch.randn(r, hidden) + lora_b = torch.randn(hidden, r) + scale = alpha / r + + lora_state = { + "base_model.model.layer.q_proj.lora_A.weight": lora_a, + "base_model.model.layer.q_proj.lora_B.weight": lora_b, + } + + config = {"r": r, "lora_alpha": alpha} + merged, was_merged = _merge_tensor_with_lora( + base, "layer.q_proj.weight", lora_state, scale, config, "cpu" + ) + assert was_merged + expected = base + scale * (lora_b @ lora_a) + assert torch.allclose(merged, expected, atol=1e-5) + + def test_merge_tensor_rslora_scale(self): + """RSLoRA should use alpha/sqrt(r) as scaling factor.""" + r = 16 + alpha = 32 + standard_scale = alpha / r # 2.0 + rslora_scale = alpha / math.sqrt(r) # 8.0 + + assert rslora_scale != standard_scale + assert abs(rslora_scale - 8.0) < 1e-6 + + def test_sharded_efficient_merge(self, tmp_path): + """End-to-end test of shard-by-shard merge.""" + hidden = 32 + r = 8 + alpha = 16 + + model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden) + adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha) + + # Create LoRA weights + lora_state = { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn( + r, hidden + ), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn( + hidden, r + ), + "base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn( + r, hidden + ), + "base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn( + hidden, r + ), + } + safetensors.torch.save_file( + lora_state, adapter_dir / "adapter_model.safetensors" + ) + + output_dir = tmp_path / "output" + merge_lora_sharded_efficient( + base_model_path=model_dir, + lora_adapter_path=adapter_dir, + output_path=output_dir, + device="cpu", + ) + + # Verify output exists and has merged weights + merged = safetensors.torch.load_file(output_dir / "model.safetensors") + scale = alpha / r + + q_key = "model.layers.0.self_attn.q_proj.weight" + expected_q = base_weights[q_key] + scale * ( + lora_state[f"base_model.model.{q_key[:-7]}.lora_B.weight"] + @ lora_state[f"base_model.model.{q_key[:-7]}.lora_A.weight"] + ) + assert torch.allclose(merged[q_key], expected_q, atol=1e-5) + + # Embedding should be unchanged + assert torch.equal( + merged["model.embed_tokens.weight"], + base_weights["model.embed_tokens.weight"], + ) + + def test_dora_merge(self): + """DoRA merge applies magnitude normalization via PEFT.""" + hidden = 32 + r = 8 + alpha = 16 + scale = alpha / r + + base = torch.randn(hidden, hidden) + lora_a = torch.randn(r, hidden) + lora_b = torch.randn(hidden, r) + magnitude = torch.randn(hidden).abs() + 0.1 + + lora_state = { + "base_model.model.layer.q_proj.lora_A.weight": lora_a, + "base_model.model.layer.q_proj.lora_B.weight": lora_b, + "base_model.model.layer.q_proj.lora_magnitude_vector": magnitude, + } + + config = {"r": r, "lora_alpha": alpha, "use_dora": True} + merged, was_merged = _merge_tensor_with_lora( + base, + "layer.q_proj.weight", + lora_state, + scale, + config, + "cpu", + use_dora=True, + ) + assert was_merged + + # The merge should differ from both base and base+delta (DoRA applies normalization) + delta = scale * (lora_b @ lora_a) + assert not torch.allclose(merged, base, atol=1e-3) + assert not torch.allclose(merged, base + delta, atol=1e-3) + + def test_fuse_unfuse_moe_merge(self): + """Test fuse→merge→unfuse for MoE expert weights (WeightConverter path).""" + from axolotl.cli.utils.lora_merge import _fuse_and_unfuse_with_merge + + 
hidden = 16 + intermediate = 32 + num_experts = 4 + r = 4 + alpha = 8 + scale = alpha / r + + # Simulate checkpoint format: per-expert separate tensors + shard_tensors = {} + for i in range(num_experts): + shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = ( + torch.randn(intermediate, hidden) + ) + shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = ( + torch.randn(intermediate, hidden) + ) + shard_tensors[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = ( + torch.randn(hidden, intermediate) + ) + shard_tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn( + hidden, hidden + ) + + # LoRA targets the fused key (runtime format) + lora_state = { + "base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight": torch.randn( + r, hidden + ), + "base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight": torch.randn( + intermediate * 2, r + ), + "base_model.model.model.layers.0.mlp.experts.down_proj.lora_A.weight": torch.randn( + r, intermediate + ), + "base_model.model.model.layers.0.mlp.experts.down_proj.lora_B.weight": torch.randn( + hidden, r + ), + } + + # Build converters matching qwen2_moe pattern + from transformers.core_model_loading import ( + Concatenate, + MergeModulelist, + WeightConverter, + ) + + converters = [ + WeightConverter( + source_patterns=[ + "mlp.experts.*.gate_proj.weight", + "mlp.experts.*.up_proj.weight", + ], + target_patterns="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="mlp.experts.*.down_proj.weight", + target_patterns="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ] + + config = {"r": r, "lora_alpha": alpha} + result, merged_count, processed_keys = _fuse_and_unfuse_with_merge( + shard_tensors, converters, lora_state, scale, config, "cpu" + ) + + # Should have merged 2 LoRA targets (gate_up_proj and down_proj) + assert merged_count == 2 + + # Processed keys include original per-expert keys (removed) + fused keys (added) + assert len(processed_keys) > 0 + + # Output should be in fused format (runtime keys) + assert "model.layers.0.mlp.experts.gate_up_proj" in result + assert "model.layers.0.mlp.experts.down_proj" in result + + # Per-expert keys should be removed + for i in range(num_experts): + assert f"model.layers.0.mlp.experts.{i}.gate_proj.weight" not in result + + # Non-expert tensor should be passed through + assert "model.layers.0.self_attn.q_proj.weight" in result + + # Verify fused tensors are 3D (stacked experts) + gate_up = result["model.layers.0.mlp.experts.gate_up_proj"] + assert gate_up.ndim == 3 + assert gate_up.shape[0] == num_experts # [num_experts, intermediate*2, hidden] + + # Verify the fused LoRA delta was applied correctly + # Reconstruct the fused base (stack per-expert, concat gate+up) + gate_stack = torch.stack( + [ + shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] + for i in range(num_experts) + ] + ) + up_stack = torch.stack( + [ + shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] + for i in range(num_experts) + ] + ) + base_fused = torch.cat([gate_stack, up_stack], dim=1) + lora_a = lora_state[ + "base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight" + ] + lora_b = lora_state[ + "base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight" + ] + expected_fused = base_fused + scale * (lora_b @ lora_a) + assert torch.allclose(gate_up, expected_fused, atol=1e-5) + + def 
test_param_wrapper_merge_math(self): + """ParamWrapper merge via PEFT's get_delta_weight matches manual einsum.""" + num_experts = 4 + r = 2 + in_features = 8 + out_features = 4 + alpha = 4 + + base = torch.randn(num_experts, in_features, out_features) + lora_a = torch.randn(r * num_experts, in_features) + lora_b = torch.randn(out_features, r * num_experts) + + config = {"r": r, "lora_alpha": alpha} + delta = _build_peft_layer_and_get_delta( + lora_a, lora_b, config, base, is_param_wrapper=True + ) + assert delta.shape == base.shape + + merged = base + delta + + # Verify against manual einsum + scale = alpha / r + wa = lora_a.reshape(num_experts, r, in_features) + wb = lora_b.reshape(out_features, r, num_experts) + manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale + for e in range(num_experts): + assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), ( + f"Expert {e} mismatch" + ) + + def test_param_wrapper_nesting_dim_filter(self): + """_find_param_wrapper_lora skips wrong-dimension LoRA at outer level.""" + num_experts = 4 + r = 2 + + # Outer LoRA (gate_up_proj): A=[r*E, 8], B=[16, r*E] + # Inner LoRA (down_proj via base_layer): A=[r*E, 16], B=[8, r*E] + lora_state = { + "base_model.model.mod.experts.lora_A.weight": torch.randn( + r * num_experts, 8 + ), + "base_model.model.mod.experts.lora_B.weight": torch.randn( + 16, r * num_experts + ), + "base_model.model.mod.experts.base_layer.lora_A.weight": torch.randn( + r * num_experts, 16 + ), + "base_model.model.mod.experts.base_layer.lora_B.weight": torch.randn( + 8, r * num_experts + ), + } + + # gate_up_proj shape [4, 8, 16] — should match outer LoRA + a, b, name = _find_param_wrapper_lora( + lora_state, "mod.experts.gate_up_proj", tensor_shape=(4, 8, 16) + ) + assert a is not None and name == "gate_up_proj" + assert a.shape == (r * num_experts, 8) # outer + + # down_proj shape [4, 16, 8] — outer dims don't match, should find inner + a, b, name = _find_param_wrapper_lora( + lora_state, "mod.experts.down_proj", tensor_shape=(4, 16, 8) + ) + assert a is not None and name == "down_proj" + assert a.shape == (r * num_experts, 16) # inner (base_layer) + + # shape that matches neither — should return None + a, b, name = _find_param_wrapper_lora( + lora_state, "mod.experts.other", tensor_shape=(4, 99, 99) + ) + assert a is None + + def test_find_lora_weights_with_renamings(self): + """Weight renamings let checkpoint keys match LoRA keys.""" + lora_state = { + "base_model.model.layers.0.mlp.fc1.lora_A.weight": torch.randn(8, 32), + "base_model.model.layers.0.mlp.fc1.lora_B.weight": torch.randn(32, 8), + } + # Direct lookup fails (checkpoint has "ff0", LoRA has "fc1") + a, b = find_lora_weights(lora_state, "layers.0.mlp.ff0.weight") + assert a is None + + # With renaming ff0 → fc1, it should match + a, b = find_lora_weights( + lora_state, "layers.0.mlp.ff0.weight", weight_renamings={"ff0": "fc1"} + ) + assert a is not None + assert a.shape == (8, 32) + + def test_unmatched_tensors_pass_through(self): + """Tensors with no matching LoRA are returned unchanged.""" + lora_state = { + "base_model.model.layer.q_proj.lora_A.weight": torch.randn(8, 32), + "base_model.model.layer.q_proj.lora_B.weight": torch.randn(32, 8), + } + + # 1D tensor (layernorm) — never matched + ln = torch.randn(32) + merged, was_merged = _merge_tensor_with_lora( + ln, "layer.norm.weight", lora_state, 2.0, {}, "cpu" + ) + assert not was_merged + assert torch.equal(merged, ln) + + # 2D tensor with no matching key + unrelated = torch.randn(64, 32) + 
merged, was_merged = _merge_tensor_with_lora( + unrelated, "layer.other_proj.weight", lora_state, 2.0, {}, "cpu" + ) + assert not was_merged + assert torch.equal(merged, unrelated) + + def test_fan_in_fan_out_transpose(self): + """fan_in_fan_out config transposes the LoRA delta.""" + hidden = 16 + r = 4 + alpha = 4 # scale = 1.0 + + base = torch.randn(hidden, hidden) + lora_a = torch.randn(r, hidden) + lora_b = torch.randn(hidden, r) + + lora_state = { + "base_model.model.layer.proj.lora_A.weight": lora_a, + "base_model.model.layer.proj.lora_B.weight": lora_b, + } + + config_normal = {"r": r, "lora_alpha": alpha} + config_fif = {"r": r, "lora_alpha": alpha, "fan_in_fan_out": True} + + merged_normal, _ = _merge_tensor_with_lora( + base, "layer.proj.weight", lora_state, 1.0, config_normal, "cpu" + ) + merged_fif, _ = _merge_tensor_with_lora( + base, "layer.proj.weight", lora_state, 1.0, config_fif, "cpu" + ) + + delta = (alpha / r) * (lora_b @ lora_a) + assert torch.allclose(merged_normal, base + delta, atol=1e-5) + assert torch.allclose(merged_fif, base + delta.T, atol=1e-5) + assert not torch.allclose(merged_normal, merged_fif, atol=1e-5) + + def test_rslora_end_to_end(self, tmp_path): + """RSLoRA adapter uses alpha/sqrt(r) scaling in sharded merge.""" + hidden = 16 + r = 16 + alpha = 32 + + model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden) + adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_rslora=True) + + lora_a = torch.randn(r, hidden) + lora_b = torch.randn(hidden, r) + lora_state = { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a, + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b, + } + safetensors.torch.save_file( + lora_state, adapter_dir / "adapter_model.safetensors" + ) + + output_dir = tmp_path / "output" + merge_lora_sharded_efficient( + base_model_path=model_dir, + lora_adapter_path=adapter_dir, + output_path=output_dir, + device="cpu", + ) + + merged = safetensors.torch.load_file(output_dir / "model.safetensors") + rslora_scale = alpha / math.sqrt(r) # 8.0, not 2.0 + q_key = "model.layers.0.self_attn.q_proj.weight" + expected = base_weights[q_key] + rslora_scale * (lora_b @ lora_a) + assert torch.allclose(merged[q_key], expected, atol=1e-5) + + # Confirm it differs from standard scale + wrong_scale = alpha / r # 2.0 + wrong_expected = base_weights[q_key] + wrong_scale * (lora_b @ lora_a) + assert not torch.allclose(merged[q_key], wrong_expected, atol=1e-3) + + def test_multi_shard_index_json(self, tmp_path): + """Multi-shard merge generates a correct weight-map index.""" + hidden = 16 + r = 4 + alpha = 8 + + model_dir = tmp_path / "base_model" + model_dir.mkdir() + (model_dir / "config.json").write_text("{}") + + # Create 2 shards + shard1 = {"model.layers.0.weight": torch.randn(hidden, hidden)} + shard2 = {"model.layers.1.weight": torch.randn(hidden, hidden)} + safetensors.torch.save_file( + shard1, model_dir / "model-00001-of-00002.safetensors" + ) + safetensors.torch.save_file( + shard2, model_dir / "model-00002-of-00002.safetensors" + ) + + # Write a base model index (will be skipped by copy_non_model_files) + base_index = { + "metadata": {}, + "weight_map": { + "model.layers.0.weight": "model-00001-of-00002.safetensors", + "model.layers.1.weight": "model-00002-of-00002.safetensors", + }, + } + (model_dir / "model.safetensors.index.json").write_text(json.dumps(base_index)) + + adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha) + safetensors.torch.save_file({}, 
adapter_dir / "adapter_model.safetensors") + + output_dir = tmp_path / "output" + merge_lora_sharded_efficient( + base_model_path=model_dir, + lora_adapter_path=adapter_dir, + output_path=output_dir, + device="cpu", + ) + + # Verify index was generated + index_path = output_dir / "model.safetensors.index.json" + assert index_path.exists() + with open(index_path) as f: + idx = json.load(f) + + assert "weight_map" in idx + assert len(idx["weight_map"]) == 2 + # Each key should map to a shard that exists + for _key, shard_name in idx["weight_map"].items(): + assert (output_dir / shard_name).exists(), f"Missing shard: {shard_name}" + + def test_dora_end_to_end(self, tmp_path): + """DoRA merge through the full sharded merge pipeline.""" + hidden = 16 + r = 4 + alpha = 8 + + model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden) + adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_dora=True) + + lora_a = torch.randn(r, hidden) + lora_b = torch.randn(hidden, r) + magnitude = torch.randn(hidden).abs() + 0.1 + lora_state = { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a, + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b, + "base_model.model.model.layers.0.self_attn.q_proj.lora_magnitude_vector": magnitude, + } + safetensors.torch.save_file( + lora_state, adapter_dir / "adapter_model.safetensors" + ) + + output_dir = tmp_path / "output" + merge_lora_sharded_efficient( + base_model_path=model_dir, + lora_adapter_path=adapter_dir, + output_path=output_dir, + device="cpu", + ) + + merged = safetensors.torch.load_file(output_dir / "model.safetensors") + q_key = "model.layers.0.self_attn.q_proj.weight" + + # Use PEFT's own get_delta_weight as the reference + delta = _build_peft_layer_and_get_delta( + lora_a, + lora_b, + {"r": r, "lora_alpha": alpha, "use_dora": True}, + base_weights[q_key], + magnitude=magnitude, + ) + expected = base_weights[q_key] + delta + assert torch.allclose(merged[q_key], expected, atol=1e-5) + + # Verify it differs from standard (non-DoRA) merge + standard_delta = _build_peft_layer_and_get_delta( + lora_a, + lora_b, + {"r": r, "lora_alpha": alpha}, + base_weights[q_key], + ) + assert not torch.allclose(delta, standard_delta, atol=1e-3) + + # v_proj has no LoRA weights — should be unchanged + v_key = "model.layers.0.self_attn.v_proj.weight" + assert torch.equal(merged[v_key], base_weights[v_key]), ( + "v_proj should be unchanged (no LoRA weights for it)" + ) + + def test_dora_missing_magnitude_falls_back(self): + """DoRA without magnitude vector falls back to standard LoRA merge.""" + hidden = 16 + r = 4 + alpha = 8 + scale = alpha / r + + base = torch.randn(hidden, hidden) + lora_a = torch.randn(r, hidden) + lora_b = torch.randn(hidden, r) + + # No magnitude vector in lora_state + lora_state = { + "base_model.model.layer.proj.lora_A.weight": lora_a, + "base_model.model.layer.proj.lora_B.weight": lora_b, + } + + config = {"r": r, "lora_alpha": alpha, "use_dora": True} + merged, was_merged = _merge_tensor_with_lora( + base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True + ) + assert was_merged + # No magnitude vector → PEFT creates DoRA layer but with default magnitude, + # which produces a result different from plain W + scale * B @ A. + # Just verify it was merged (not unchanged). + assert not torch.equal(merged, base)