Increased test coverage for lora/qlora (#3147)
* config_val tests
* remove config val (not needed)
* config validation
* parameter freeze validation
* merge/unmerge tests
* removal unwanted
* rename
* lint
* updated lint
* Update tests/utils/lora/test_config_validation_lora.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* pytest skip + mock fix
* nitpicks
* revert some nitpicks

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
tests/utils/lora/test_config_validation_lora.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


class TestLoRAConfigValidation:
    """Test suite for LoRA/QLoRA configuration validation"""

    def test_basic_configuration_validation(self):
        """Test basic LoRA configuration validation"""

        valid_config = DictDefault(
            {
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.1,
                "lora_target_modules": ["q_proj", "v_proj"],
                "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 1e-5,
                "base_model": "dummy_model",
            }
        )

        result = validate_config(valid_config)
        assert result["adapter"] == "lora"

        with pytest.raises(ValueError, match="not compatible with DoRA"):
            invalid_config = DictDefault(
                {
                    "adapter": "lora",
                    "lora_mlp_kernel": True,
                    "peft_use_dora": True,
                    "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
                    "micro_batch_size": 1,
                    "gradient_accumulation_steps": 1,
                    "learning_rate": 1e-5,
                    "base_model": "dummy_model",
                }
            )
            validate_config(invalid_config)

    def test_qlora_4bit_validation(self):
        """Test QLoRA 4-bit configuration validation"""
        valid_config = DictDefault(
            {
                "adapter": "qlora",
                "load_in_4bit": True,
                "bnb_4bit_compute_dtype": "float16",
                "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 1e-5,
                "base_model": "dummy_model",
            }
        )
        result = validate_config(valid_config)
        assert result["adapter"] == "qlora"
        assert result["load_in_4bit"] is True

        # Test QLoRA without 4-bit (should fail via PEFT validation)
        with pytest.raises(ValueError, match=r"Require cfg\.load_in_4bit"):
            invalid_config = DictDefault(
                {
                    "adapter": "qlora",
                    "load_in_4bit": False,
                    "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
                    "micro_batch_size": 1,
                    "gradient_accumulation_steps": 1,
                    "learning_rate": 1e-5,
                    "base_model": "dummy_model",
                }
            )
            validate_config(invalid_config)

        # Test QLoRA with 8-bit (incompatible)
        with pytest.raises(ValueError, match="Can't load qlora in 8bit"):
            invalid_config = DictDefault(
                {
                    "adapter": "qlora",
                    "load_in_8bit": True,
                    "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
                    "micro_batch_size": 1,
                    "gradient_accumulation_steps": 1,
                    "learning_rate": 1e-5,
                    "base_model": "dummy_model",
                }
            )
            validate_config(invalid_config)
tests/utils/lora/test_freeze_lora.py (new file, 261 lines)
@@ -0,0 +1,261 @@
import importlib.util
from unittest.mock import Mock

import pytest
import torch
import torch.nn as nn

from axolotl.kernels.lora import get_lora_parameters

PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None


class TestLoRAParameterFreezing:
    """Test suite for LoRA parameter freezing validation."""

    def setup_method(self):
        self.dtype = torch.float32

    def create_mock_lora_layer(
        self, has_adapters=True, adapters_disabled=False, merged=False
    ):
        """Create a mock LoRA layer for testing."""
        mock_layer = Mock()

        base_layer = Mock()
        base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
        base_layer.bias = torch.randn(512, dtype=self.dtype)

        if has_adapters:
            mock_layer.base_layer = base_layer
            mock_layer.disable_adapters = adapters_disabled
            mock_layer.merged = merged

            mock_layer.active_adapters = ["default"]
            mock_layer.lora_A = {"default": Mock()}
            mock_layer.lora_B = {"default": Mock()}
            mock_layer.scaling = {"default": 0.1}

            mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
            mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
        else:
            mock_layer.weight = base_layer.weight
            mock_layer.bias = base_layer.bias

        return mock_layer

    def test_parameter_freezing_adapters_disabled(self):
        """Test that LoRA parameters are None when adapters are disabled."""
        layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None
        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_freezing_adapters_merged(self):
        """Test that LoRA parameters are None when adapters are merged."""
        layer = self.create_mock_lora_layer(has_adapters=True, merged=True)

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None

        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_freezing_no_adapters(self):
        """Test parameter behavior when no adapters are present."""
        layer = self.create_mock_lora_layer(has_adapters=False)

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None

        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_active_adapters_enabled(self):
        """Test that LoRA parameters are returned when adapters are active."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # All parameters should be returned
        assert W is not None
        assert b is not None
        assert A is not None
        assert B is not None
        assert s is not None
        assert s == 0.1

    def test_parameter_shapes_consistency(self):
        """Test that parameter shapes are consistent when active."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Check shape consistency
        assert W.shape == (512, 256)
        assert b.shape == (512,)
        assert A.shape == (16, 256)
        assert B.shape == (512, 16)

    def test_parameter_dtypes_consistency(self):
        """Test that parameter dtypes are consistent."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        assert W.dtype == self.dtype
        assert b.dtype == self.dtype
        assert A.dtype == self.dtype
        assert B.dtype == self.dtype

    def test_quantization_state_handling(self):
        """Test that quantization state is properly handled."""
        layer = self.create_mock_lora_layer(has_adapters=True)

        quant_state_mock = Mock()
        layer.base_layer.weight.quant_state = quant_state_mock

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        assert quant_state == quant_state_mock

    def test_multiple_adapters_active_adapter_selection(self):
        """Test that the correct adapter is selected when multiple adapters exist."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        layer.lora_A["adapter2"] = Mock()
        layer.lora_B["adapter2"] = Mock()
        layer.scaling["adapter2"] = 0.2

        layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
        layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)

        layer.active_adapters = ["adapter2"]

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        assert s == 0.2
        assert torch.equal(A, layer.lora_A["adapter2"].weight)
        assert torch.equal(B, layer.lora_B["adapter2"].weight)


class TestLoRAParameterFreezingIntegration:
    """Integration tests for parameter freezing with actual LoRA layers."""

    @pytest.mark.skipif(
        not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
    )
    def test_parameter_freezing_with_real_lora_layer(self):
        """Test parameter freezing with actual PEFT LoRA layer."""
        from peft import LoraConfig, get_peft_model

        class SimpleModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(256, 512)

            def forward(self, x):
                return self.linear(x)

        base_model = SimpleModel()
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["linear"],
            lora_dropout=0.1,
        )
        model = get_peft_model(base_model, lora_config)
        lora_layer = model.base_model.model.linear

        # Test with adapters enabled
        W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
        assert A is not None
        assert B is not None
        assert s is not None

        # Test with adapters disabled
        model.disable_adapter_layers()
        W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
        assert A is None
        assert B is None
        assert s is None

    @pytest.mark.skipif(
        not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
    )
    def test_parameter_freezing_gradient_behavior(self):
        """Test that frozen parameters don't receive gradients."""
        from peft import LoraConfig, get_peft_model

        class SimpleModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(256, 512)

            def forward(self, x):
                return self.linear(x)

        base_model = SimpleModel()
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["linear"],
            lora_dropout=0.1,
        )
        model = get_peft_model(base_model, lora_config)

        x = torch.randn(1, 256)
        target = torch.randn(1, 512)

        model.enable_adapter_layers()
        output = model(x)
        loss = nn.MSELoss()(output, target)
        loss.backward()

        lora_layer = model.base_model.model.linear
        has_lora_grads = any(
            param.grad is not None
            for name, param in lora_layer.named_parameters()
            if "lora_" in name
        )
        assert has_lora_grads, (
            "LoRA parameters should have gradients when adapters are enabled"
        )

        model.zero_grad()
        model.disable_adapter_layers()
        output = model(x)
        loss = nn.MSELoss()(output, target)

        # With adapters disabled, backward() would fail if no parameter
        # requires grad, so only backpropagate when something still does.
        any_requires_grad = any(param.requires_grad for param in model.parameters())
        if any_requires_grad:
            loss.backward()

        has_lora_grads_disabled = any(
            param.grad is not None
            for name, param in lora_layer.named_parameters()
            if "lora_" in name
        )
        assert not has_lora_grads_disabled, (
            "LoRA parameters should not have gradients when adapters are disabled"
        )

        model.zero_grad()
        del model, base_model, lora_layer, x, target, output, loss
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
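The contract these freezing tests pin down is that `get_lora_parameters` returns the six-tuple `(W, b, quant_state, A, B, s)`, with `A`/`B`/`s` set to `None` whenever adapters are disabled, merged, or absent. As a minimal sketch of what a consumer of that contract does with it, the hypothetical helper below (plain torch, illustrative shapes matching the mock layer, not part of the test suite) skips the low-rank path exactly when those values are `None`:

import torch

def lora_forward(x, W, b, A=None, B=None, s=None):
    """Dense forward with an optional low-rank LoRA correction.

    Mirrors the tested contract: when A/B/s come back as None
    (adapters disabled, merged, or absent), only the base path runs;
    otherwise the scaled low-rank term is added.
    """
    out = x @ W.T + b
    if A is not None and B is not None and s is not None:
        out = out + (x @ A.T) @ B.T * s
    return out

x = torch.randn(4, 256)
W, b = torch.randn(512, 256), torch.randn(512)
A, B, s = torch.randn(16, 256), torch.randn(512, 16), 0.1

frozen = lora_forward(x, W, b)           # A/B/s None -> base path only
active = lora_forward(x, W, b, A, B, s)  # low-rank correction applied
assert not torch.allclose(frozen, active)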
tests/utils/lora/test_merge_lora.py (new file, 181 lines)
@@ -0,0 +1,181 @@
from unittest.mock import Mock, patch

import torch

from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault


class TestAdapterMergeUnmerge:
    """Test suite for LoRA adapter merging/unmerging functionality"""

    def setup_method(self):
        self.dtype = torch.float32
        self.device = torch.device("cpu")

    def create_mock_base_model(self, vocab_size=1000, hidden_size=256):
        """Create a mock base model with linear layers"""
        mock_model = Mock()

        mock_model.config = Mock()
        mock_model.config.vocab_size = vocab_size
        mock_model.config.hidden_size = hidden_size

        mock_model.q_proj = Mock()
        mock_model.q_proj.weight = torch.randn(
            hidden_size, hidden_size, dtype=self.dtype
        )
        mock_model.q_proj.bias = torch.randn(hidden_size, dtype=self.dtype)

        mock_model.v_proj = Mock()
        mock_model.v_proj.weight = torch.randn(
            hidden_size, hidden_size, dtype=self.dtype
        )
        mock_model.v_proj.bias = torch.randn(hidden_size, dtype=self.dtype)

        return mock_model

    def create_mock_lora_model(self, base_model, r=8, alpha=16):
        """Create a mock LoRA model wrapping the base model"""
        mock_lora_model = Mock()
        mock_lora_model.base_model = base_model

        mock_lora_model.merge_and_unload = None
        mock_lora_model.to = Mock(return_value=mock_lora_model)

        mock_lora_model.generation_config = Mock()
        mock_lora_model.config = Mock()

        self.original_q_weight = base_model.q_proj.weight.clone()
        self.original_v_weight = base_model.v_proj.weight.clone()

        mock_lora_model.peft_config = {"default": Mock()}
        mock_lora_model.peft_config["default"].r = r
        mock_lora_model.peft_config["default"].lora_alpha = alpha

        self.lora_A_q = torch.randn(
            r, base_model.q_proj.weight.shape[1], dtype=self.dtype
        )
        self.lora_B_q = torch.randn(
            base_model.q_proj.weight.shape[0], r, dtype=self.dtype
        )

        self.lora_A_v = torch.randn(
            r, base_model.v_proj.weight.shape[1], dtype=self.dtype
        )
        self.lora_B_v = torch.randn(
            base_model.v_proj.weight.shape[0], r, dtype=self.dtype
        )

        self.scaling = alpha / r

        def mock_merge_and_unload(progressbar=False):
            """Simulate the actual merge operation"""
            # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
            delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
            delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling

            base_model.q_proj.weight = self.original_q_weight + delta_q
            base_model.v_proj.weight = self.original_v_weight + delta_v

            return base_model

        mock_lora_model.merge_and_unload = mock_merge_and_unload
        return mock_lora_model

    def test_basic_lora_merge_unmerge_cycle(self):
        """Test: original_weights -> merge -> unmerge -> should equal original_weights"""

        base_model = self.create_mock_base_model()
        lora_model = self.create_mock_lora_model(base_model)

        original_q_weight = self.original_q_weight.clone()
        original_v_weight = self.original_v_weight.clone()

        merged_model = lora_model.merge_and_unload()

        assert not torch.equal(merged_model.q_proj.weight, original_q_weight)
        assert not torch.equal(merged_model.v_proj.weight, original_v_weight)

        delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
        delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling

        unmerged_q_weight = merged_model.q_proj.weight - delta_q
        unmerged_v_weight = merged_model.v_proj.weight - delta_v

        assert torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)
        assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)

    def test_merge_weight_calculation_accuracy(self):
        """Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)"""
        base_model = self.create_mock_base_model()
        lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)

        expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
        expected_merged_q = self.original_q_weight + expected_delta_q
        merged_model = lora_model.merge_and_unload()

        assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)

    @patch("axolotl.cli.merge_lora.load_model_and_tokenizer")
    def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):
        base_model = self.create_mock_base_model()
        lora_model = self.create_mock_lora_model(base_model)
        tokenizer = Mock()
        processor = None

        mock_load_model.return_value = (lora_model, tokenizer, processor)

        cfg = DictDefault(
            {
                "save_safetensors": True,
                "torch_dtype": torch.float32,
                "local_rank": 0,
                "output_dir": str(tmp_path),
            }
        )

        with (
            patch("pathlib.Path.mkdir"),
            patch.object(base_model, "save_pretrained") as mock_save_model,
            patch.object(tokenizer, "save_pretrained") as mock_save_tokenizer,
        ):
            do_merge_lora(cfg=cfg)

        mock_save_model.assert_called_once()
        mock_save_tokenizer.assert_called_once()

    def test_quantized_model_merge_compatibility(self):
        """Test 4-bit/8-bit model merging scenarios"""
        base_model = self.create_mock_base_model()

        # Mock quantized weights
        base_model.q_proj.weight.quant_state = Mock()
        base_model.q_proj.weight.quant_state.dtype = torch.uint8

        lora_model = self.create_mock_lora_model(base_model)

        merged_model = lora_model.merge_and_unload()
        assert merged_model is not None

    @patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": ""})
    def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):
        """Test lora_on_cpu configuration during merge"""
        cfg = DictDefault(
            {
                "lora_on_cpu": True,
                "save_safetensors": True,
                "output_dir": str(tmp_path),
                "local_rank": 0,
            }
        )

        with patch("axolotl.cli.merge_lora.load_model_and_tokenizer") as mock_load:
            base_model = self.create_mock_base_model()
            lora_model = self.create_mock_lora_model(base_model)
            mock_load.return_value = (lora_model, Mock(), None)

            with patch("pathlib.Path.mkdir"), patch("torch.save"):
                do_merge_lora(cfg=cfg)

            assert mock_load.called
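The round-trip these merge tests rely on is a single linear-algebra identity: merging folds `(B @ A) * scaling` into the base weight, and subtracting the same delta recovers the original weight up to float tolerance. A standalone sketch with illustrative tensors (assumes only torch; shapes mirror the mock q_proj with the default r=8, alpha=16):

import torch

W = torch.randn(256, 256)  # base weight, as in the mock q_proj
A = torch.randn(8, 256)    # lora_A (r=8)
B = torch.randn(256, 8)    # lora_B
scaling = 16 / 8           # alpha / r, matching the mock default

merged = W + (B @ A) * scaling          # merge_and_unload folds the delta in
recovered = merged - (B @ A) * scaling  # unmerge subtracts the same delta
assert torch.allclose(recovered, W, atol=1e-6)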