Increased test coverage for lora/qlora (#3147)

* config_val tests

* remove config val (not needed)

* config validation

* parameter freeze validation

* merge/unmerge tests

* remove unwanted code

* rename

* lint

* updated lint

* Update tests/utils/lora/test_config_validation_lora.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* pytest skip + mock fix

* nitpicks

* revert some nitpicks

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Author: VED
Committed: 2026-01-06 22:14:48 +05:30 (by GitHub)
Parent: 7bf6f70e96
Commit: e7f0d4ba5b
3 changed files with 534 additions and 0 deletions

tests/utils/lora/test_config_validation_lora.py

@@ -0,0 +1,92 @@
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


class TestLoRAConfigValidation:
    """Test suite for LoRA/QLoRA configuration validation"""

    def test_basic_configuration_validation(self):
"""Test basic LoRA configuration validation"""
valid_config = DictDefault(
{
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.1,
"lora_target_modules": ["q_proj", "v_proj"],
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["adapter"] == "lora"
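
        # Enabling lora_mlp_kernel together with peft_use_dora is rejected by
        # the validator, so this configuration should raise.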
with pytest.raises(ValueError, match="not compatible with DoRA"):
invalid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)

    def test_qlora_4bit_validation(self):
"""Test QLoRA 4-bit configuration validation"""
valid_config = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"bnb_4bit_compute_dtype": "float16",
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
# Test QLoRA without 4-bit (should fail via PEFT validation)
with pytest.raises(ValueError, match=r"Require cfg\.load_in_4bit"):
invalid_config = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": False,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
# Test QLoRA with 8-bit (incompatible)
with pytest.raises(ValueError, match="Can't load qlora in 8bit"):
invalid_config = DictDefault(
{
"adapter": "qlora",
"load_in_8bit": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)


@@ -0,0 +1,261 @@
import importlib.util
from unittest.mock import Mock
import pytest
import torch
import torch.nn as nn
from axolotl.kernels.lora import get_lora_parameters
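
# Used to skip the PEFT integration tests below when peft is not installed.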
PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None


class TestLoRAParameterFreezing:
    """Test suite for LoRA parameter freezing validation."""

    def setup_method(self):
        self.dtype = torch.float32

    def create_mock_lora_layer(
self, has_adapters=True, adapters_disabled=False, merged=False
):
"""Create a mock LoRA layer for testing."""
mock_layer = Mock()
base_layer = Mock()
base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
base_layer.bias = torch.randn(512, dtype=self.dtype)
if has_adapters:
mock_layer.base_layer = base_layer
mock_layer.disable_adapters = adapters_disabled
mock_layer.merged = merged
mock_layer.active_adapters = ["default"]
mock_layer.lora_A = {"default": Mock()}
mock_layer.lora_B = {"default": Mock()}
mock_layer.scaling = {"default": 0.1}
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
else:
mock_layer.weight = base_layer.weight
mock_layer.bias = base_layer.bias
return mock_layer
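
    # get_lora_parameters is expected to return the 6-tuple
    # (W, bias, quant_state, lora_A, lora_B, scaling); the tests below assert
    # that the last three are None whenever adapters are disabled, merged,
    # or absent.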

    def test_parameter_freezing_adapters_disabled(self):
"""Test that LoRA parameters are None when adapters are disabled."""
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
assert b is not None
# LoRA parameters should be None (frozen)
assert A is None
assert B is None
assert s is None

    def test_parameter_freezing_adapters_merged(self):
"""Test that LoRA parameters are None when adapters are merged."""
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
assert b is not None
# LoRA parameters should be None (frozen)
assert A is None
assert B is None
assert s is None

    def test_parameter_freezing_no_adapters(self):
"""Test parameter behavior when no adapters are present."""
layer = self.create_mock_lora_layer(has_adapters=False)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
assert b is not None
# LoRA parameters should be None (frozen)
assert A is None
assert B is None
assert s is None

    def test_parameter_active_adapters_enabled(self):
"""Test that LoRA parameters are returned when adapters are active."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# All parameters should be returned
assert W is not None
assert b is not None
assert A is not None
assert B is not None
assert s is not None
assert s == 0.1

    def test_parameter_shapes_consistency(self):
"""Test that parameter shapes are consistent when active."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Check shape consistency
assert W.shape == (512, 256)
assert b.shape == (512,)
assert A.shape == (16, 256)
assert B.shape == (512, 16)

    def test_parameter_dtypes_consistency(self):
"""Test that parameter dtypes are consistent."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert W.dtype == self.dtype
assert b.dtype == self.dtype
assert A.dtype == self.dtype
assert B.dtype == self.dtype

    def test_quantization_state_handling(self):
"""Test that quantization state is properly handled."""
layer = self.create_mock_lora_layer(has_adapters=True)
quant_state_mock = Mock()
layer.base_layer.weight.quant_state = quant_state_mock
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert quant_state == quant_state_mock

    def test_multiple_adapters_active_adapter_selection(self):
"""Test that the correct adapter is selected when multiple adapters exist."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
layer.lora_A["adapter2"] = Mock()
layer.lora_B["adapter2"] = Mock()
layer.scaling["adapter2"] = 0.2
layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)
layer.active_adapters = ["adapter2"]
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert s == 0.2
assert torch.equal(A, layer.lora_A["adapter2"].weight)
assert torch.equal(B, layer.lora_B["adapter2"].weight)


class TestLoRAParameterFreezingIntegration:
    """Integration tests for parameter freezing with actual LoRA layers."""

    @pytest.mark.skipif(
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
)
def test_parameter_freezing_with_real_lora_layer(self):
"""Test parameter freezing with actual PEFT LoRA layer."""
from peft import LoraConfig, get_peft_model
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(256, 512)
def forward(self, x):
return self.linear(x)
base_model = SimpleModel()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["linear"],
lora_dropout=0.1,
)
model = get_peft_model(base_model, lora_config)
lora_layer = model.base_model.model.linear
# Test with adapters enabled
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
assert A is not None
assert B is not None
assert s is not None
# Test with adapters disabled
model.disable_adapter_layers()
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
assert A is None
assert B is None
assert s is None

    @pytest.mark.skipif(
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
)
def test_parameter_freezing_gradient_behavior(self):
"""Test that frozen parameters don't receive gradients."""
from peft import LoraConfig, get_peft_model
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(256, 512)
def forward(self, x):
return self.linear(x)
base_model = SimpleModel()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["linear"],
lora_dropout=0.1,
)
model = get_peft_model(base_model, lora_config)
x = torch.randn(1, 256)
target = torch.randn(1, 512)
model.enable_adapter_layers()
output = model(x)
loss = nn.MSELoss()(output, target)
loss.backward()
lora_layer = model.base_model.model.linear
has_lora_grads = any(
param.grad is not None
for name, param in lora_layer.named_parameters()
if "lora_" in name
)
assert has_lora_grads, (
"LoRA parameters should have gradients when adapters are enabled"
)
model.zero_grad()
model.disable_adapter_layers()
output = model(x)
loss = nn.MSELoss()(output, target)
any_requires_grad = any(param.requires_grad for param in model.parameters())
if any_requires_grad:
loss.backward()
has_lora_grads_disabled = any(
param.grad is not None
for name, param in lora_layer.named_parameters()
if "lora_" in name
)
assert not has_lora_grads_disabled, (
"LoRA parameters should not have gradients when adapters are disabled"
)
model.zero_grad()
del model, base_model, lora_layer, x, target, output, loss
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


@@ -0,0 +1,181 @@
from unittest.mock import Mock, patch
import torch
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault


class TestAdapterMergeUnmerge:
    """Test suite for LoRA adapter merging/unmerging functionality"""

    def setup_method(self):
        self.dtype = torch.float32
        self.device = torch.device("cpu")

    def create_mock_base_model(self, vocab_size=1000, hidden_size=256):
"""Create a mock base model with linear layers"""
mock_model = Mock()
mock_model.config = Mock()
mock_model.config.vocab_size = vocab_size
mock_model.config.hidden_size = hidden_size
mock_model.q_proj = Mock()
mock_model.q_proj.weight = torch.randn(
hidden_size, hidden_size, dtype=self.dtype
)
mock_model.q_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
mock_model.v_proj = Mock()
mock_model.v_proj.weight = torch.randn(
hidden_size, hidden_size, dtype=self.dtype
)
mock_model.v_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
return mock_model

    def create_mock_lora_model(self, base_model, r=8, alpha=16):
"""Create a mock LoRA model wrapping the base model"""
mock_lora_model = Mock()
mock_lora_model.base_model = base_model
mock_lora_model.merge_and_unload = None
mock_lora_model.to = Mock(return_value=mock_lora_model)
mock_lora_model.generation_config = Mock()
mock_lora_model.config = Mock()
self.original_q_weight = base_model.q_proj.weight.clone()
self.original_v_weight = base_model.v_proj.weight.clone()
mock_lora_model.peft_config = {"default": Mock()}
mock_lora_model.peft_config["default"].r = r
mock_lora_model.peft_config["default"].lora_alpha = alpha
self.lora_A_q = torch.randn(
r, base_model.q_proj.weight.shape[1], dtype=self.dtype
)
self.lora_B_q = torch.randn(
base_model.q_proj.weight.shape[0], r, dtype=self.dtype
)
self.lora_A_v = torch.randn(
r, base_model.v_proj.weight.shape[1], dtype=self.dtype
)
self.lora_B_v = torch.randn(
base_model.v_proj.weight.shape[0], r, dtype=self.dtype
)
self.scaling = alpha / r

        def mock_merge_and_unload(progressbar=False):
"""Simulate the actual merge operation"""
# Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
base_model.q_proj.weight = self.original_q_weight + delta_q
base_model.v_proj.weight = self.original_v_weight + delta_v
return base_model
mock_lora_model.merge_and_unload = mock_merge_and_unload
return mock_lora_model
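
    # The mock above encodes the standard LoRA merge rule,
    #     W_merged = W_base + (B @ A) * (alpha / r),
    # so unmerging amounts to subtracting the same delta; the cycle test
    # below checks this round trip numerically.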

    def test_basic_lora_merge_unmerge_cycle(self):
"""Test: original_weights -> merge -> unmerge -> should equal original_weights"""
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
original_q_weight = self.original_q_weight.clone()
original_v_weight = self.original_v_weight.clone()
merged_model = lora_model.merge_and_unload()
assert not torch.equal(merged_model.q_proj.weight, original_q_weight)
assert not torch.equal(merged_model.v_proj.weight, original_v_weight)
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
unmerged_q_weight = merged_model.q_proj.weight - delta_q
unmerged_v_weight = merged_model.v_proj.weight - delta_v
assert torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)
assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)

    def test_merge_weight_calculation_accuracy(self):
"""Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)"""
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)
expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
expected_merged_q = self.original_q_weight + expected_delta_q
merged_model = lora_model.merge_and_unload()
assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)

    @patch("axolotl.cli.merge_lora.load_model_and_tokenizer")
def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
tokenizer = Mock()
processor = None
mock_load_model.return_value = (lora_model, tokenizer, processor)
cfg = DictDefault(
{
"save_safetensors": True,
"torch_dtype": torch.float32,
"local_rank": 0,
"output_dir": str(tmp_path),
}
)
with (
patch("pathlib.Path.mkdir"),
patch.object(base_model, "save_pretrained") as mock_save_model,
patch.object(tokenizer, "save_pretrained") as mock_save_tokenizer,
):
do_merge_lora(cfg=cfg)
mock_save_model.assert_called_once()
mock_save_tokenizer.assert_called_once()

    def test_quantized_model_merge_compatibility(self):
"""Test 4-bit/8-bit model merging scenarios"""
base_model = self.create_mock_base_model()
# Mock quantized weights
base_model.q_proj.weight.quant_state = Mock()
base_model.q_proj.weight.quant_state.dtype = torch.uint8
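        # The quant_state above mimics the attribute bitsandbytes attaches to
        # quantized weights; a Mock carrying just a dtype is enough here.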
lora_model = self.create_mock_lora_model(base_model)
merged_model = lora_model.merge_and_unload()
assert merged_model is not None

    @patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": ""})
def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):
"""Test lora_on_cpu configuration during merge"""
cfg = DictDefault(
{
"lora_on_cpu": True,
"save_safetensors": True,
"output_dir": str(tmp_path),
"local_rank": 0,
}
)
with patch("axolotl.cli.merge_lora.load_model_and_tokenizer") as mock_load:
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
mock_load.return_value = (lora_model, Mock(), None)
with patch("pathlib.Path.mkdir"), patch("torch.save"):
do_merge_lora(cfg=cfg)
assert mock_load.called
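
A quick way to exercise the new config-validation suite locally (a minimal sketch assuming a standard pytest setup; the diff confirms only this file's path, so the other two new test files are omitted):

    import pytest

    pytest.main(["-q", "tests/utils/lora/test_config_validation_lora.py"])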