Files
axolotl/tests/utils/lora/test_merge_lora.py
2026-04-20 23:09:47 +00:00

806 lines
30 KiB
Python

import json
import math
from unittest.mock import Mock, patch
import pytest
import safetensors.torch
import torch
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.cli.utils.lora_merge import (
_build_peft_layer_and_get_delta,
_find_param_wrapper_lora,
_merge_tensor_with_lora,
find_lora_weights,
merge_lora_sharded_efficient,
)
from axolotl.utils.dict import DictDefault
class TestAdapterMergeUnmerge:
"""Test suite for LoRA adapter merging/unmerging functionality"""
def setup_method(self):
self.dtype = torch.float32
self.device = torch.device("cpu")
def create_mock_base_model(self, vocab_size=1000, hidden_size=256):
"""Create a mock base model with linear layers"""
mock_model = Mock()
mock_model.config = Mock()
mock_model.config.vocab_size = vocab_size
mock_model.config.hidden_size = hidden_size
mock_model.q_proj = Mock()
mock_model.q_proj.weight = torch.randn(
hidden_size, hidden_size, dtype=self.dtype
)
mock_model.q_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
mock_model.v_proj = Mock()
mock_model.v_proj.weight = torch.randn(
hidden_size, hidden_size, dtype=self.dtype
)
mock_model.v_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
return mock_model
def create_mock_lora_model(self, base_model, r=8, alpha=16):
"""Create a mock LoRA model wrapping the base model"""
mock_lora_model = Mock()
mock_lora_model.base_model = base_model
mock_lora_model.merge_and_unload = None
mock_lora_model.to = Mock(return_value=mock_lora_model)
mock_lora_model.generation_config = Mock()
mock_lora_model.config = Mock()
self.original_q_weight = base_model.q_proj.weight.clone()
self.original_v_weight = base_model.v_proj.weight.clone()
mock_lora_model.peft_config = {"default": Mock()}
mock_lora_model.peft_config["default"].r = r
mock_lora_model.peft_config["default"].lora_alpha = alpha
self.lora_A_q = torch.randn(
r, base_model.q_proj.weight.shape[1], dtype=self.dtype
)
self.lora_B_q = torch.randn(
base_model.q_proj.weight.shape[0], r, dtype=self.dtype
)
self.lora_A_v = torch.randn(
r, base_model.v_proj.weight.shape[1], dtype=self.dtype
)
self.lora_B_v = torch.randn(
base_model.v_proj.weight.shape[0], r, dtype=self.dtype
)
self.scaling = alpha / r
def mock_merge_and_unload(progressbar=False):
"""Simulate the actual merge operation"""
# Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
base_model.q_proj.weight = self.original_q_weight + delta_q
base_model.v_proj.weight = self.original_v_weight + delta_v
return base_model
mock_lora_model.merge_and_unload = mock_merge_and_unload
return mock_lora_model
def test_basic_lora_merge_unmerge_cycle(self):
"""Test: original_weights -> merge -> unmerge -> should equal original_weights"""
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
original_q_weight = self.original_q_weight.clone()
original_v_weight = self.original_v_weight.clone()
merged_model = lora_model.merge_and_unload()
assert not torch.equal(merged_model.q_proj.weight, original_q_weight)
assert not torch.equal(merged_model.v_proj.weight, original_v_weight)
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
unmerged_q_weight = merged_model.q_proj.weight - delta_q
unmerged_v_weight = merged_model.v_proj.weight - delta_v
assert torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)
assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)
def test_merge_weight_calculation_accuracy(self):
"""Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)"""
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)
expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
expected_merged_q = self.original_q_weight + expected_delta_q
merged_model = lora_model.merge_and_unload()
assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)
@patch("axolotl.cli.merge_lora.load_model_and_tokenizer")
def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
tokenizer = Mock()
processor = None
mock_load_model.return_value = (lora_model, tokenizer, processor)
cfg = DictDefault(
{
"save_safetensors": True,
"torch_dtype": torch.float32,
"local_rank": 0,
"output_dir": str(tmp_path),
"merge_method": "legacy",
}
)
with (
patch("pathlib.Path.mkdir"),
patch.object(base_model, "save_pretrained") as mock_save_model,
patch.object(tokenizer, "save_pretrained") as mock_save_tokenizer,
):
do_merge_lora(cfg=cfg)
mock_save_model.assert_called_once()
mock_save_tokenizer.assert_called_once()
def test_quantized_model_merge_compatibility(self):
"""Test 4-bit/8-bit model merging scenarios"""
base_model = self.create_mock_base_model()
# Mock quantized weights
base_model.q_proj.weight.quant_state = Mock()
base_model.q_proj.weight.quant_state.dtype = torch.uint8
lora_model = self.create_mock_lora_model(base_model)
merged_model = lora_model.merge_and_unload()
assert merged_model is not None
@patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": ""})
def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):
"""Test lora_on_cpu configuration during merge"""
cfg = DictDefault(
{
"lora_on_cpu": True,
"save_safetensors": True,
"output_dir": str(tmp_path),
"local_rank": 0,
"merge_method": "legacy",
}
)
with patch("axolotl.cli.merge_lora.load_model_and_tokenizer") as mock_load:
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
mock_load.return_value = (lora_model, Mock(), None)
with patch("pathlib.Path.mkdir"), patch("torch.save"):
do_merge_lora(cfg=cfg)
assert mock_load.called
class TestEfficientMerge:
"""Test suite for memory-efficient shard-by-shard LoRA merge."""
def _make_adapter(self, tmp_path, r=8, alpha=16, use_dora=False, use_rslora=False):
"""Create a minimal adapter directory with config + weights."""
adapter_dir = tmp_path / "adapter"
adapter_dir.mkdir()
config = {
"r": r,
"lora_alpha": alpha,
"target_modules": ["q_proj", "v_proj"],
"task_type": "CAUSAL_LM",
"bias": "none",
"use_dora": use_dora,
"use_rslora": use_rslora,
}
(adapter_dir / "adapter_config.json").write_text(json.dumps(config))
return adapter_dir, config
def _make_base_model(self, tmp_path, hidden=32):
"""Create a minimal base model directory with one shard."""
model_dir = tmp_path / "base_model"
model_dir.mkdir()
weights = {
"model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden),
"model.layers.0.self_attn.v_proj.weight": torch.randn(hidden, hidden),
"model.embed_tokens.weight": torch.randn(100, hidden),
}
safetensors.torch.save_file(weights, model_dir / "model.safetensors")
# Minimal config files
(model_dir / "config.json").write_text("{}")
return model_dir, weights
def test_find_lora_weights(self):
lora_state = {
"base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
8, 32
),
"base_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
32, 8
),
}
a, b = find_lora_weights(lora_state, "layers.0.self_attn.q_proj.weight")
assert a is not None and b is not None
assert a.shape == (8, 32)
a, b = find_lora_weights(lora_state, "layers.0.self_attn.v_proj.weight")
assert a is None and b is None
def test_merge_tensor_basic(self):
hidden = 32
r = 8
alpha = 16
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
scale = alpha / r
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
}
config = {"r": r, "lora_alpha": alpha}
merged, was_merged = _merge_tensor_with_lora(
base, "layer.q_proj.weight", lora_state, scale, config, "cpu"
)
assert was_merged
expected = base + scale * (lora_b @ lora_a)
assert torch.allclose(merged, expected, atol=1e-5)
def test_merge_tensor_rslora_scale(self):
"""RSLoRA should use alpha/sqrt(r) as scaling factor."""
r = 16
alpha = 32
standard_scale = alpha / r # 2.0
rslora_scale = alpha / math.sqrt(r) # 8.0
assert rslora_scale != standard_scale
assert abs(rslora_scale - 8.0) < 1e-6
def test_sharded_efficient_merge(self, tmp_path):
"""End-to-end test of shard-by-shard merge."""
hidden = 32
r = 8
alpha = 16
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
# Create LoRA weights
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(
hidden, r
),
"base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(
hidden, r
),
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
# Verify output exists and has merged weights
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
scale = alpha / r
q_key = "model.layers.0.self_attn.q_proj.weight"
expected_q = base_weights[q_key] + scale * (
lora_state[f"base_model.model.{q_key[:-7]}.lora_B.weight"]
@ lora_state[f"base_model.model.{q_key[:-7]}.lora_A.weight"]
)
assert torch.allclose(merged[q_key], expected_q, atol=1e-5)
# Embedding should be unchanged
assert torch.equal(
merged["model.embed_tokens.weight"],
base_weights["model.embed_tokens.weight"],
)
def test_dora_merge(self):
"""DoRA merge applies magnitude normalization via PEFT."""
hidden = 32
r = 8
alpha = 16
scale = alpha / r
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
magnitude = torch.randn(hidden).abs() + 0.1
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": lora_a,
"base_model.model.layer.q_proj.lora_B.weight": lora_b,
"base_model.model.layer.q_proj.lora_magnitude_vector": magnitude,
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
merged, was_merged = _merge_tensor_with_lora(
base,
"layer.q_proj.weight",
lora_state,
scale,
config,
"cpu",
use_dora=True,
)
assert was_merged
# The merge should differ from both base and base+delta (DoRA applies normalization)
delta = scale * (lora_b @ lora_a)
assert not torch.allclose(merged, base, atol=1e-3)
assert not torch.allclose(merged, base + delta, atol=1e-3)
def test_fuse_unfuse_moe_merge(self):
"""Test fuse→merge→unfuse for MoE expert weights (WeightConverter path)."""
from axolotl.cli.utils.lora_merge import _fuse_and_unfuse_with_merge
hidden = 16
intermediate = 32
num_experts = 4
r = 4
alpha = 8
scale = alpha / r
# Simulate checkpoint format: per-expert separate tensors
shard_tensors = {}
for i in range(num_experts):
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = (
torch.randn(intermediate, hidden)
)
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = (
torch.randn(intermediate, hidden)
)
shard_tensors[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = (
torch.randn(hidden, intermediate)
)
shard_tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn(
hidden, hidden
)
# LoRA targets the fused key (runtime format)
lora_state = {
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight": torch.randn(
r, hidden
),
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight": torch.randn(
intermediate * 2, r
),
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_A.weight": torch.randn(
r, intermediate
),
"base_model.model.model.layers.0.mlp.experts.down_proj.lora_B.weight": torch.randn(
hidden, r
),
}
# Build converters matching qwen2_moe pattern
from transformers.core_model_loading import (
Concatenate,
MergeModulelist,
WeightConverter,
)
converters = [
WeightConverter(
source_patterns=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
]
config = {"r": r, "lora_alpha": alpha}
result, merged_count, processed_keys = _fuse_and_unfuse_with_merge(
shard_tensors, converters, lora_state, scale, config, "cpu"
)
# Should have merged 2 LoRA targets (gate_up_proj and down_proj)
assert merged_count == 2
# Processed keys include original per-expert keys (removed) + fused keys (added)
assert len(processed_keys) > 0
# Output should be in fused format (runtime keys)
assert "model.layers.0.mlp.experts.gate_up_proj" in result
assert "model.layers.0.mlp.experts.down_proj" in result
# Per-expert keys should be removed
for i in range(num_experts):
assert f"model.layers.0.mlp.experts.{i}.gate_proj.weight" not in result
# Non-expert tensor should be passed through
assert "model.layers.0.self_attn.q_proj.weight" in result
# Verify fused tensors are 3D (stacked experts)
gate_up = result["model.layers.0.mlp.experts.gate_up_proj"]
assert gate_up.ndim == 3
assert gate_up.shape[0] == num_experts # [num_experts, intermediate*2, hidden]
# Verify the fused LoRA delta was applied correctly
# Reconstruct the fused base (stack per-expert, concat gate+up)
gate_stack = torch.stack(
[
shard_tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"]
for i in range(num_experts)
]
)
up_stack = torch.stack(
[
shard_tensors[f"model.layers.0.mlp.experts.{i}.up_proj.weight"]
for i in range(num_experts)
]
)
base_fused = torch.cat([gate_stack, up_stack], dim=1)
lora_a = lora_state[
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_A.weight"
]
lora_b = lora_state[
"base_model.model.model.layers.0.mlp.experts.gate_up_proj.lora_B.weight"
]
expected_fused = base_fused + scale * (lora_b @ lora_a)
assert torch.allclose(gate_up, expected_fused, atol=1e-5)
def test_param_wrapper_merge_math(self):
"""ParamWrapper merge via PEFT's get_delta_weight matches manual einsum."""
num_experts = 4
r = 2
in_features = 8
out_features = 4
alpha = 4
# PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
base = torch.randn(num_experts, out_features, in_features)
lora_a = torch.randn(r * num_experts, in_features)
lora_b = torch.randn(out_features, r * num_experts)
config = {"r": r, "lora_alpha": alpha}
delta = _build_peft_layer_and_get_delta(
lora_a, lora_b, config, base, is_param_wrapper=True
)
assert delta.shape == base.shape
merged = base + delta
# Verify against manual einsum
scale = alpha / r
wa = lora_a.reshape(num_experts, r, in_features)
wb = lora_b.reshape(out_features, r, num_experts)
manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
for e in range(num_experts):
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
f"Expert {e} mismatch"
)
def test_param_wrapper_nesting_dim_filter(self):
"""_find_param_wrapper_lora skips wrong-dimension LoRA at outer level."""
num_experts = 4
r = 2
# Outer LoRA (gate_up_proj): A=[r*E, 8], B=[16, r*E]
# Inner LoRA (down_proj via base_layer): A=[r*E, 16], B=[8, r*E]
lora_state = {
"base_model.model.mod.experts.lora_A.weight": torch.randn(
r * num_experts, 8
),
"base_model.model.mod.experts.lora_B.weight": torch.randn(
16, r * num_experts
),
"base_model.model.mod.experts.base_layer.lora_A.weight": torch.randn(
r * num_experts, 16
),
"base_model.model.mod.experts.base_layer.lora_B.weight": torch.randn(
8, r * num_experts
),
}
# gate_up_proj shape [4, 8, 16] — should match outer LoRA
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.gate_up_proj", tensor_shape=(4, 8, 16)
)
assert a is not None and name == "gate_up_proj"
assert a.shape == (r * num_experts, 8) # outer
# down_proj shape [4, 16, 8] — outer dims don't match, should find inner
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.down_proj", tensor_shape=(4, 16, 8)
)
assert a is not None and name == "down_proj"
assert a.shape == (r * num_experts, 16) # inner (base_layer)
# shape that matches neither — should return None
a, b, name = _find_param_wrapper_lora(
lora_state, "mod.experts.other", tensor_shape=(4, 99, 99)
)
assert a is None
def test_find_lora_weights_with_renamings(self):
"""Weight renamings let checkpoint keys match LoRA keys."""
lora_state = {
"base_model.model.layers.0.mlp.fc1.lora_A.weight": torch.randn(8, 32),
"base_model.model.layers.0.mlp.fc1.lora_B.weight": torch.randn(32, 8),
}
# Direct lookup fails (checkpoint has "ff0", LoRA has "fc1")
a, b = find_lora_weights(lora_state, "layers.0.mlp.ff0.weight")
assert a is None
# With renaming ff0 → fc1, it should match
a, b = find_lora_weights(
lora_state, "layers.0.mlp.ff0.weight", weight_renamings={"ff0": "fc1"}
)
assert a is not None
assert a.shape == (8, 32)
def test_unmatched_tensors_pass_through(self):
"""Tensors with no matching LoRA are returned unchanged."""
lora_state = {
"base_model.model.layer.q_proj.lora_A.weight": torch.randn(8, 32),
"base_model.model.layer.q_proj.lora_B.weight": torch.randn(32, 8),
}
# 1D tensor (layernorm) — never matched
ln = torch.randn(32)
merged, was_merged = _merge_tensor_with_lora(
ln, "layer.norm.weight", lora_state, 2.0, {}, "cpu"
)
assert not was_merged
assert torch.equal(merged, ln)
# 2D tensor with no matching key
unrelated = torch.randn(64, 32)
merged, was_merged = _merge_tensor_with_lora(
unrelated, "layer.other_proj.weight", lora_state, 2.0, {}, "cpu"
)
assert not was_merged
assert torch.equal(merged, unrelated)
def test_fan_in_fan_out_transpose(self):
"""fan_in_fan_out config transposes the LoRA delta."""
hidden = 16
r = 4
alpha = 4 # scale = 1.0
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
lora_state = {
"base_model.model.layer.proj.lora_A.weight": lora_a,
"base_model.model.layer.proj.lora_B.weight": lora_b,
}
config_normal = {"r": r, "lora_alpha": alpha}
config_fif = {"r": r, "lora_alpha": alpha, "fan_in_fan_out": True}
merged_normal, _ = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, 1.0, config_normal, "cpu"
)
merged_fif, _ = _merge_tensor_with_lora(
base, "layer.proj.weight", lora_state, 1.0, config_fif, "cpu"
)
delta = (alpha / r) * (lora_b @ lora_a)
assert torch.allclose(merged_normal, base + delta, atol=1e-5)
assert torch.allclose(merged_fif, base + delta.T, atol=1e-5)
assert not torch.allclose(merged_normal, merged_fif, atol=1e-5)
def test_rslora_end_to_end(self, tmp_path):
"""RSLoRA adapter uses alpha/sqrt(r) scaling in sharded merge."""
hidden = 16
r = 16
alpha = 32
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_rslora=True)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
rslora_scale = alpha / math.sqrt(r) # 8.0, not 2.0
q_key = "model.layers.0.self_attn.q_proj.weight"
expected = base_weights[q_key] + rslora_scale * (lora_b @ lora_a)
assert torch.allclose(merged[q_key], expected, atol=1e-5)
# Confirm it differs from standard scale
wrong_scale = alpha / r # 2.0
wrong_expected = base_weights[q_key] + wrong_scale * (lora_b @ lora_a)
assert not torch.allclose(merged[q_key], wrong_expected, atol=1e-3)
def test_multi_shard_index_json(self, tmp_path):
"""Multi-shard merge generates a correct weight-map index."""
hidden = 16
r = 4
alpha = 8
model_dir = tmp_path / "base_model"
model_dir.mkdir()
(model_dir / "config.json").write_text("{}")
# Create 2 shards
shard1 = {"model.layers.0.weight": torch.randn(hidden, hidden)}
shard2 = {"model.layers.1.weight": torch.randn(hidden, hidden)}
safetensors.torch.save_file(
shard1, model_dir / "model-00001-of-00002.safetensors"
)
safetensors.torch.save_file(
shard2, model_dir / "model-00002-of-00002.safetensors"
)
# Write a base model index (will be skipped by copy_non_model_files)
base_index = {
"metadata": {},
"weight_map": {
"model.layers.0.weight": "model-00001-of-00002.safetensors",
"model.layers.1.weight": "model-00002-of-00002.safetensors",
},
}
(model_dir / "model.safetensors.index.json").write_text(json.dumps(base_index))
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha)
safetensors.torch.save_file({}, adapter_dir / "adapter_model.safetensors")
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
# Verify index was generated
index_path = output_dir / "model.safetensors.index.json"
assert index_path.exists()
with open(index_path) as f:
idx = json.load(f)
assert "weight_map" in idx
assert len(idx["weight_map"]) == 2
# Each key should map to a shard that exists
for _key, shard_name in idx["weight_map"].items():
assert (output_dir / shard_name).exists(), f"Missing shard: {shard_name}"
def test_dora_end_to_end(self, tmp_path):
"""DoRA merge through the full sharded merge pipeline."""
hidden = 16
r = 4
alpha = 8
model_dir, base_weights = self._make_base_model(tmp_path, hidden=hidden)
adapter_dir, _ = self._make_adapter(tmp_path, r=r, alpha=alpha, use_dora=True)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
magnitude = torch.randn(hidden).abs() + 0.1
lora_state = {
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": lora_a,
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": lora_b,
"base_model.model.model.layers.0.self_attn.q_proj.lora_magnitude_vector": magnitude,
}
safetensors.torch.save_file(
lora_state, adapter_dir / "adapter_model.safetensors"
)
output_dir = tmp_path / "output"
merge_lora_sharded_efficient(
base_model_path=model_dir,
lora_adapter_path=adapter_dir,
output_path=output_dir,
device="cpu",
)
merged = safetensors.torch.load_file(output_dir / "model.safetensors")
q_key = "model.layers.0.self_attn.q_proj.weight"
# Use PEFT's own get_delta_weight as the reference
delta = _build_peft_layer_and_get_delta(
lora_a,
lora_b,
{"r": r, "lora_alpha": alpha, "use_dora": True},
base_weights[q_key],
magnitude=magnitude,
)
expected = base_weights[q_key] + delta
assert torch.allclose(merged[q_key], expected, atol=1e-5)
# Verify it differs from standard (non-DoRA) merge
standard_delta = _build_peft_layer_and_get_delta(
lora_a,
lora_b,
{"r": r, "lora_alpha": alpha},
base_weights[q_key],
)
assert not torch.allclose(delta, standard_delta, atol=1e-3)
# v_proj has no LoRA weights — should be unchanged
v_key = "model.layers.0.self_attn.v_proj.weight"
assert torch.equal(merged[v_key], base_weights[v_key]), (
"v_proj should be unchanged (no LoRA weights for it)"
)
def test_dora_missing_magnitude_raises(self):
"""DoRA with missing magnitude vector raises an explicit error."""
hidden = 16
r = 4
alpha = 8
scale = alpha / r
base = torch.randn(hidden, hidden)
lora_a = torch.randn(r, hidden)
lora_b = torch.randn(hidden, r)
# No magnitude vector in lora_state
lora_state = {
"base_model.model.layer.proj.lora_A.weight": lora_a,
"base_model.model.layer.proj.lora_B.weight": lora_b,
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
_merge_tensor_with_lora(
base,
"layer.proj.weight",
lora_state,
scale,
config,
"cpu",
use_dora=True,
)