feat: merge-lora iterate through bins without loading (#3095)

* merge_method added

* merge_efficient core implement

* Update src/axolotl/cli/merge_lora.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/lora_merge_efficient.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* standard to legacy + rstrip + try/except for do_merge_lora_efficient(cfg=cfg)

* fix: 'dict' object has no attribute 'lora_alpha'

* info -> debug

* lint

* lint2

* moved everything to cpu + performance improvements

* lint

* Update src/axolotl/cli/merge_lora.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/cli/merge_lora.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* string handling + try/except remove

* merge_method -> merge_lora_methods

* remove duplicate cal + safetensor + move to lora_merge.py

* lint

* handle quant-dequant, handle experts

* fix parameter merging and prefer peft's native merge logic per module

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
Author: VED
Date: 2026-03-25 18:11:32 +05:30
Committed by: GitHub
parent ff0f67c730
commit b55706b9f6
5 changed files with 1735 additions and 5 deletions
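The change is a shard-by-shard merge: each base-model safetensors file is walked tensor by tensor, the LoRA delta is added wherever an adapter weight matches, and the shard is written back out, so the full model never has to be loaded at once. A minimal sketch of that idea follows; merge_shard is a hypothetical helper for illustration, not the code added in src/axolotl/cli/utils/lora_merge.py, and it assumes the adapter state dict and scaling factor have already been loaded.

from pathlib import Path

import safetensors.torch
from safetensors import safe_open


def merge_shard(shard_path: Path, lora_state: dict, scale: float, out_dir: Path) -> None:
    """Merge LoRA deltas into a single shard, one tensor at a time."""
    merged = {}
    # safe_open memory-maps the file, so only the current tensor is resident
    with safe_open(str(shard_path), framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            if key.endswith(".weight"):
                # PEFT adapters prefix the base key with "base_model.model."
                prefix = "base_model.model." + key[: -len(".weight")]
                lora_a = lora_state.get(prefix + ".lora_A.weight")
                lora_b = lora_state.get(prefix + ".lora_B.weight")
                if lora_a is not None and lora_b is not None:
                    # standard LoRA merge: W' = W + scale * (B @ A)
                    tensor = tensor + scale * (lora_b @ lora_a)
            merged[key] = tensor
    safetensors.torch.save_file(merged, str(out_dir / shard_path.name))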


@@ -1,8 +1,18 @@
import json
import math
from unittest.mock import Mock, patch
import safetensors.torch
import torch
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.cli.utils.lora_merge import (
_build_peft_layer_and_get_delta,
_find_param_wrapper_lora,
_merge_tensor_with_lora,
find_lora_weights,
merge_lora_sharded_efficient,
)
from axolotl.utils.dict import DictDefault
@@ -132,6 +142,7 @@ class TestAdapterMergeUnmerge:
"torch_dtype": torch.float32,
"local_rank": 0,
"output_dir": str(tmp_path),
"merge_method": "legacy",
}
)
@@ -167,6 +178,7 @@ class TestAdapterMergeUnmerge:
"save_safetensors": True,
"output_dir": str(tmp_path),
"local_rank": 0,
"merge_method": "legacy",
}
)
@@ -179,3 +191,611 @@ class TestAdapterMergeUnmerge:
do_merge_lora(cfg=cfg)
assert mock_load.called
class TestEfficientMerge:
"""Test suite for memory-efficient shard-by-shard LoRA merge."""
def _make_adapter(self, tmp_path, r=8, alpha=16, use_dora=False, use_rslora=False):
"""Create a minimal adapter directory with config + weights."""
adapter_dir = tmp_path / "adapter"
adapter_dir.mkdir()
config = {
"r": r,
"lora_alpha": alpha,
"target_modules": ["q_proj", "v_proj"],
"task_type": "CAUSAL_LM",
"bias": "none",
"use_dora": use_dora,
"use_rslora": use_rslora,
}
(adapter_dir / "adapter_config.json").write_text(json.dumps(config))
return adapter_dir, config
def _make_base_model(self, tmp_path, hidden=32):
"""Create a minimal base model directory with one shard."""
model_dir = tmp_path / "base_model"
model_dir.mkdir()
weights = {
"model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden),
"model.layers.0.self_attn.v_proj.weight": torch.randn(hidden, hidden),
"model.embed_tokens.weight": torch.randn(100, hidden),
}
safetensors.torch.save_file(weights, model_dir / "model.safetensors")
# Minimal config files
(model_dir / "config.json").write_text("{}")
return model_dir, weights
def test_find_lora_weights(self):
lora_state = {
"base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
8, 32
),
"base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
32, 8
),
}
a, b = find_lora_weights(lora_state, "layers.0.self_attn.q_proj.weight")
assert a is not None and b is not None
assert a.shape == (8, 32)
a, b = find_lora_weights(lora_state, "layers.0.self_attn.v_proj.weight")
assert a is None and b is None
def test_merge_tensor_basic(self):
hidden = 32
r = 8
alpha = 16
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
scale = alpha / r
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
}
config = {"r": r, "lora_alpha": alpha}
merged, was_merged = _merge_tensor_with_lora(
base, "layer.q_proj.weight", lora_state, scale, config, "cpu"
)
assert was_merged
expected = base + scale * (lora_b @ lora_a)
assert torch.allclose(merged, expected, atol=1e-5)
def test_merge_tensor_rslora_scale(self):
"""RSLoRA should use alpha/sqrt(r) as scaling factor."""
r = 16
alpha = 32
standard_scale = alpha / r # 2.0
rslora_scale = alpha / math.sqrt(r) # 8.0
assert rslora_scale != standard_scale
assert abs(rslora_scale - 8.0) < 1e-6
def test_sharded_efficient_merge(self, tmp_path):
"""End-to-end test of shard-by-shard merge."""
hidden = 32
r = 8
alpha = 16
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
# Create LoRA weights
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
hidden, r
),
"base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(
hidden, r
),
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
# Verify output exists and has merged weights
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
scale = alpha / r
q_key = "model.layers.0.self_attn.q_proj.weight"
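# q_key[:-7] strips the trailing ".weight" so the base key lines up with the adapter's lora_A/lora_B naming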
expected_q = base_weights[q_key] + scale * (
lora_state[f"base_model.model.{q_key[:-7]}.lora_B.weight"]
@ lora_state[f"base_model.model.{q_key[:-7]}.lora_A.weight"]
)
assert torch.allclose(merged[q_key], expected_q, atol=1e-5)
# Embedding should be unchanged
assert torch.equal(
merged["model.embed_tokens.weight"],
base_weights["model.embed_tokens.weight"],
)
def test_dora_merge(self):
"""DoRA merge applies magnitude normalization via PEFT."""
hidden = 32
r = 8
alpha = 16
scale = alpha / r
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
magnitude = torch.randn(hidden).abs() + 0.1
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
"base_model.model.layer.q_proj.lora_magnitude_vector": magnitude,
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
merged, was_merged = _merge_tensor_with_lora(
base,
"layer.q_proj.weight",
lora_state,
scale,
config,
"cpu",
use_dora=True,
)
assert was_merged
# The merge should differ from both base and base+delta (DoRA applies normalization)
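# Roughly: W' = magnitude * (W + scale * B @ A) / column_norm(W + scale * B @ A)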
delta = scale * (lora_b @ lora_a)
assert not torch.allclose(merged, base, atol=1e-3)
assert not torch.allclose(merged, base + delta, atol=1e-3)
def test_fuse_unfuse_moe_merge(self):
"""Test fuse→merge→unfuse for MoE expert weights (WeightConverter path)."""
from axolotl.cli.utils.lora_merge import _fuse_and_unfuse_with_merge
hidden = 16
intermediate = 32
num_experts = 4
r = 4
alpha = 8
scale = alpha / r
# Simulate checkpoint format: per-expert separate tensors
shard_tensors = {}
for i in range(num_experts):
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = (
torch.randn(intermediate, hidden)
)
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = (
torch.randn(intermediate, hidden)
)
shard_tensors[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = (
torch.randn(hidden, intermediate)
)
shard_tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn(
hidden, hidden
)
# LoRA targets the fused key (runtime format)
lora_state = {
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight": torch.randn(
intermediate * 2, r
),
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_A.weight": torch.randn(
r, intermediate
),
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_B.weight": torch.randn(
hidden, r
),
}
# Build converters matching qwen2_moe pattern
from transformers.core_model_loading import (
Concatenate,
MergeModulelist,
WeightConverter,
)
converters = [
WeightConverter(
source_patterns=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
]
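# These converters mirror the fusion reconstructed below: per-expert 2D weights are stacked into a 3D tensor (expert dim first), then gate and up are concatenated along dim=1 into gate_up_proj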
config = {"r": r, "lora_alpha": alpha}
result, merged_count, processed_keys = _fuse_and_unfuse_with_merge(
shard_tensors, converters, lora_state, scale, config, "cpu"
)
# Should have merged 2 LoRA targets (gate_up_proj and down_proj)
assert merged_count == 2
# Processed keys include original per-expert keys (removed) + fused keys (added)
assert len(processed_keys) > 0
# Output should be in fused format (runtime keys)
assert "model.layers.0.mlp.experts.gate_up_proj" in result
assert "model.layers.0.mlp.experts.down_proj" in result
# Per-expert keys should be removed
for i in range(num_experts):
assert f"model.layers.0.mlp.experts.{i}.gate_proj.weight" not in result
# Non-expert tensor should be passed through
assert "model.layers.0.self_attn.q_proj.weight" in result
# Verify fused tensors are 3D (stacked experts)
gate_up = result["model.layers.0.mlp.experts.gate_up_proj"]
assert gate_up.ndim == 3
assert gate_up.shape[0] == num_experts # [num_experts, intermediate*2, hidden]
# Verify the fused LoRA delta was applied correctly
# Reconstruct the fused base (stack per-expert, concat gate+up)
gate_stack = torch.stack(
[
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"]
for i in range(num_experts)
]
)
up_stack = torch.stack(
[
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"]
for i in range(num_experts)
]
)
base_fused = torch.cat([gate_stack, up_stack], dim=1)
lora_a = lora_state[
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight"
]
lora_b = lora_state[
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight"
]
expected_fused = base_fused + scale * (lora_b @ lora_a)
assert torch.allclose(gate_up, expected_fused, atol=1e-5)
def test_param_wrapper_merge_math(self):
"""ParamWrapper merge via PEFT's get_delta_weight matches manual einsum."""
num_experts = 4
r = 2
in_features = 8
out_features = 4
alpha = 4
base = torch.randn(num_experts, in_features, out_features)
lora_a = torch.randn(r * num_experts, in_features)
lora_b = torch.randn(out_features, r * num_experts)
config = {"r": r, "lora_alpha": alpha}
delta = _build_peft_layer_and_get_delta(
lora_a, lora_b, config, base, is_param_wrapper=True
)
assert delta.shape == base.shape
merged = base + delta
# Verify against manual einsum
scale = alpha / r
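# The packed factors carry all experts along the rank dim: A is [r*E, in], B is [out, r*E]; reshape into per-expert blocks and contract expert by expert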
wa = lora_a.reshape(num_experts, r, in_features)
wb = lora_b.reshape(out_features, r, num_experts)
manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
for e in range(num_experts):
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
f"Expert {e} mismatch"
)
def test_param_wrapper_nesting_dim_filter(self):
"""_find_param_wrapper_lora skips wrong-dimension LoRA at outer level."""
num_experts = 4
r = 2
# Outer LoRA (gate_up_proj): A=[r*E, 8], B=[16, r*E]
# Inner LoRA (down_proj via base_layer): A=[r*E, 16], B=[8, r*E]
lora_state = {
"base_model.model.mod.experts.lora_A.weight": torch.randn(
r * num_experts, 8
),
"base_model.model.mod.experts.lora_B.weight": torch.randn(
16, r * num_experts
),
"base_model.model.mod.experts.base_layer.lora_A.weight": torch.randn(
r * num_experts, 16
),
"base_model.model.mod.experts.base_layer.lora_B.weight": torch.randn(
8, r * num_experts
),
}
# gate_up_proj shape [4, 8, 16] — should match outer LoRA
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.gate_up_proj", tensor_shape=(4, 8, 16)
)
assert a is not None and name == "gate_up_proj"
assert a.shape == (r * num_experts, 8) # outer
# down_proj shape [4, 16, 8] — outer dims don't match, should find inner
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.down_proj", tensor_shape=(4, 16, 8)
)
assert a is not None and name == "down_proj"
assert a.shape == (r * num_experts, 16) # inner (base_layer)
# shape that matches neither — should return None
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.other", tensor_shape=(4, 99, 99)
)
assert a is None
def test_find_lora_weights_with_renamings(self):
"""Weight renamings let checkpoint keys match LoRA keys."""
lora_state = {
"base_model.model.layers.0.mlp.fc1.lora_A.weight": torch.randn(8, 32),
"base_model.model.layers.0.mlp.fc1.lora_B.weight": torch.randn(32, 8),
}
# Direct lookup fails (checkpoint has "ff0", LoRA has "fc1")
a, b = find_lora_weights(lora_state, "layers.0.mlp.ff0.weight")
assert a is None
# With renaming ff0 → fc1, it should match
a, b = find_lora_weights(
lora_state, "layers.0.mlp.ff0.weight", weight_renamings={"ff0": "fc1"}
)
assert a is not None
assert a.shape == (8, 32)
def test_unmatched_tensors_pass_through(self):
"""Tensors with no matching LoRA are returned unchanged."""
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": torch.randn(8, 32),
"base_model.model.layer.q_proj.lora_B.weight": torch.randn(32, 8),
}
# 1D tensor (layernorm) — never matched
ln = torch.randn(32)
merged, was_merged = _merge_tensor_with_lora(
ln, "layer.norm.weight", lora_state, 2.0, {}, "cpu"
)
assert not was_merged
assert torch.equal(merged, ln)
# 2D tensor with no matching key
unrelated = torch.randn(64, 32)
merged, was_merged = _merge_tensor_with_lora(
unrelated, "layer.other_proj.weight", lora_state, 2.0, {}, "cpu"
)
assert not was_merged
assert torch.equal(merged, unrelated)
def test_fan_in_fan_out_transpose(self):
"""fan_in_fan_out config transposes the LoRA delta."""
hidden = 16
r = 4
alpha = 4 # scale = 1.0
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
lora_state = {
"base_model.model.layer.proj.lora_A.weight": lora_a,
"base_model.model.layer.proj.lora_B.weight": lora_b,
}
config_normal = {"r": r, "lora_alpha": alpha}
config_fif = {"r": r, "lora_alpha": alpha, "fan_in_fan_out": True}
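# fan_in_fan_out flags weights stored as (in_features, out_features), e.g. GPT-2-style Conv1D, so the delta is transposed before being added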
merged_normal, _ = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, 1.0, config_normal, "cpu"
)
merged_fif, _ = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, 1.0, config_fif, "cpu"
)
delta = (alpha / r) * (lora_b @ lora_a)
assert torch.allclose(merged_normal, base + delta, atol=1e-5)
assert torch.allclose(merged_fif, base + delta.T, atol=1e-5)
assert not torch.allclose(merged_normal, merged_fif, atol=1e-5)
def test_rslora_end_to_end(self, tmp_path):
"""RSLoRA adapter uses alpha/sqrt(r) scaling in sharded merge."""
hidden = 16
r = 16
alpha = 32
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_rslora=True)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
rslora_scale = alpha / math.sqrt(r) # 8.0, not 2.0
q_key = "model.layers.0.self_attn.q_proj.weight"
expected = base_weights[q_key] + rslora_scale * (lora_b @ lora_a)
assert torch.allclose(merged[q_key], expected, atol=1e-5)
# Confirm it differs from standard scale
wrong_scale = alpha / r # 2.0
wrong_expected = base_weights[q_key] + wrong_scale * (lora_b @ lora_a)
assert not torch.allclose(merged[q_key], wrong_expected, atol=1e-3)
def test_multi_shard_index_json(self, tmp_path):
"""Multi-shard merge generates a correct weight-map index."""
hidden = 16
r = 4
alpha = 8
model_dir = tmp_path / "base_model"
model_dir.mkdir()
(model_dir / "config.json").write_text("{}")
# Create 2 shards
shard1 = {"model.layers.0.weight": torch.randn(hidden, hidden)}
shard2 = {"model.layers.1.weight": torch.randn(hidden, hidden)}
safetensors.torch.save_file(
shard1, model_dir / "model-00001-of-00002.safetensors"
)
safetensors.torch.save_file(
shard2, model_dir / "model-00002-of-00002.safetensors"
)
# Write a base model index (will be skipped by copy_non_model_files)
base_index = {
"metadata": {},
"weight_map": {
"model.layers.0.weight": "model-00001-of-00002.safetensors",
"model.layers.1.weight": "model-00002-of-00002.safetensors",
},
}
(model_dir / "model.safetensors.index.json").write_text(json.dumps(base_index))
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
safetensors.torch.save_file({}, adapter_dir / "adapter_model.safetensors")
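# The adapter state is empty, so every shard passes through unchanged; this test only checks that the weight-map index is regenerated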
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
# Verify index was generated
index_path = output_dir / "model.safetensors.index.json"
assert index_path.exists()
with open(index_path) as f:
idx = json.load(f)
assert "weight_map" in idx
assert len(idx["weight_map"]) == 2
# Each key should map to a shard that exists
for _key, shard_name in idx["weight_map"].items():
assert (output_dir / shard_name).exists(), f"Missing shard: {shard_name}"
def test_dora_end_to_end(self, tmp_path):
"""DoRA merge through the full sharded merge pipeline."""
hidden = 16
r = 4
alpha = 8
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_dora=True)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
magnitude = torch.randn(hidden).abs() + 0.1
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
"base_model.model.model.layers.0.self_attn.q_proj.lora_magnitude_vector": magnitude,
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
q_key = "model.layers.0.self_attn.q_proj.weight"
# Use PEFT's own get_delta_weight as the reference
delta = _build_peft_layer_and_get_delta(
lora_a,
lora_b,
{"r": r, "lora_alpha": alpha, "use_dora": True},
base_weights[q_key],
magnitude=magnitude,
)
expected = base_weights[q_key] + delta
assert torch.allclose(merged[q_key], expected, atol=1e-5)
# Verify it differs from standard (non-DoRA) merge
standard_delta = _build_peft_layer_and_get_delta(
lora_a,
lora_b,
{"r": r, "lora_alpha": alpha},
base_weights[q_key],
)
assert not torch.allclose(delta, standard_delta, atol=1e-3)
# v_proj has no LoRA weights — should be unchanged
v_key = "model.layers.0.self_attn.v_proj.weight"
assert torch.equal(merged[v_key], base_weights[v_key]), (
"v_proj should be unchanged (no LoRA weights for it)"
)
def test_dora_missing_magnitude_falls_back(self):
"""DoRA without magnitude vector falls back to standard LoRA merge."""
hidden = 16
r = 4
alpha = 8
scale = alpha / r
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
# No magnitude vector in lora_state
lora_state = {
"base_model.model.layer.proj.lora_A.weight": lora_a,
"base_model.model.layer.proj.lora_B.weight": lora_b,
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
merged, was_merged = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
)
assert was_merged
# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
# which produces a result different from plain W + scale * B @ A.
# Just verify it was merged (not unchanged).
assert not torch.equal(merged, base)