# Recovered from a git format-patch whose line breaks had been lost.
#   Commit:  e7f0d4ba5be91a452850d6e3f1656c97ec1d0fd3
#   From:    VED <146507396+ved1beta@users.noreply.github.com>
#   Date:    Tue, 6 Jan 2026 22:14:48 +0530
#   Subject: [PATCH] Increased test coverage for lora/qlora (#3147)
#   Co-authored-by: NanoCode012
#   Files created by the patch:
#     tests/utils/lora/test_config_validation_lora.py
#     tests/utils/lora/test_freeze_lora.py
#     tests/utils/lora/test_merge_lora.py
#
# File: tests/utils/lora/test_config_validation_lora.py
"""Validation tests for LoRA / QLoRA configuration handling."""

import copy

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Placeholder settings that every configuration in this module carries.
_COMMON_SETTINGS = {
    "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
    "micro_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "learning_rate": 1e-5,
    "base_model": "dummy_model",
}


def _make_config(**adapter_settings):
    """Build a ``DictDefault`` from ``adapter_settings`` plus the shared settings.

    The adapter-specific keys come first, followed by the shared keys, which
    is the key order the tests used when each dict was written out by hand.
    The shared settings are deep-copied so that one test cannot leak a
    mutation of the nested ``datasets`` list into another.
    """
    return DictDefault({**adapter_settings, **copy.deepcopy(_COMMON_SETTINGS)})


class TestLoRAConfigValidation:
    """Test suite for LoRA/QLoRA configuration validation"""

    def test_basic_configuration_validation(self):
        """Test basic LoRA configuration validation

        A plain LoRA config must validate and keep its adapter name, while
        combining ``lora_mlp_kernel`` with ``peft_use_dora`` must be rejected.
        """
        valid_config = _make_config(
            adapter="lora",
            lora_r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            lora_target_modules=["q_proj", "v_proj"],
        )

        result = validate_config(valid_config)
        assert result["adapter"] == "lora"

        # Build the fixture outside the ``raises`` block so that only
        # ``validate_config`` can satisfy the expected exception.
        invalid_config = _make_config(
            adapter="lora",
            lora_mlp_kernel=True,
            peft_use_dora=True,
        )
        with pytest.raises(ValueError, match="not compatible with DoRA"):
            validate_config(invalid_config)

    def test_qlora_4bit_validation(self):
        """Test QLoRA 4-bit configuration validation

        Covers one accepted config (``load_in_4bit=True``) and two rejected
        ones (``load_in_4bit=False`` and ``load_in_8bit=True``).
        """
        valid_config = _make_config(
            adapter="qlora",
            load_in_4bit=True,
            bnb_4bit_compute_dtype="float16",
        )
        result = validate_config(valid_config)
        assert result["adapter"] == "qlora"
        assert result["load_in_4bit"] is True

        # Test QLoRA without 4-bit (should fail via PEFT validation)
        without_4bit = _make_config(adapter="qlora", load_in_4bit=False)
        with pytest.raises(ValueError, match=r"Require cfg\.load_in_4bit"):
            validate_config(without_4bit)

        # Test QLoRA with 8-bit (incompatible)
        with_8bit = _make_config(adapter="qlora", load_in_8bit=True)
        with pytest.raises(ValueError, match="Can't load qlora in 8bit"):
            validate_config(with_8bit)
# File: tests/utils/lora/test_freeze_lora.py
"""Tests for the values ``get_lora_parameters`` returns for LoRA layers."""

import importlib.util
from unittest.mock import Mock

import pytest
import torch
import torch.nn as nn

from axolotl.kernels.lora import get_lora_parameters

PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None

# Shared skip marker for the tests that need a real PEFT installation.
_requires_peft = pytest.mark.skipif(
    not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
)


def _build_peft_linear_model():
    """Return a PEFT LoRA model wrapping a single ``nn.Linear(256, 512)``.

    The adapter targets the ``linear`` module with r=16, lora_alpha=32 and
    lora_dropout=0.1. ``peft`` is imported lazily so that this module still
    imports when PEFT is not installed.
    """
    from peft import LoraConfig, get_peft_model

    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(256, 512)

        def forward(self, x):
            return self.linear(x)

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["linear"],
        lora_dropout=0.1,
    )
    return get_peft_model(SimpleModel(), lora_config)


def _has_lora_grads(layer):
    """Return True if any parameter of ``layer`` named ``*lora_*`` has a gradient."""
    return any(
        param.grad is not None
        for name, param in layer.named_parameters()
        if "lora_" in name
    )


class TestLoRAParameterFreezing:
    """Test suite for LoRA parameter freezing validation."""

    def setup_method(self):
        """Fix the dtype used for every tensor attached to the mock layers."""
        self.dtype = torch.float32

    def create_mock_lora_layer(
        self, has_adapters=True, adapters_disabled=False, merged=False
    ):
        """Create a mock LoRA layer for testing.

        Args:
            has_adapters: When True the mock gets a ``base_layer`` plus one
                ``"default"`` adapter (``lora_A`` 16x256, ``lora_B`` 512x16,
                scaling 0.1). When False only ``weight`` and ``bias`` are set
                directly on the mock.
            adapters_disabled: Value stored in ``disable_adapters``.
            merged: Value stored in ``merged``.

        Returns:
            The configured ``Mock``.
        """
        base_layer = Mock()
        base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
        base_layer.bias = torch.randn(512, dtype=self.dtype)

        mock_layer = Mock()
        if not has_adapters:
            # NOTE(review): a bare Mock creates any attribute on access, so
            # hasattr() probes made by the code under test will succeed here
            # even for attributes never set. Consider Mock(spec=[...]) if the
            # intent is a layer that truly lacks adapter attributes -- confirm
            # against get_lora_parameters.
            mock_layer.weight = base_layer.weight
            mock_layer.bias = base_layer.bias
            return mock_layer

        mock_layer.base_layer = base_layer
        mock_layer.disable_adapters = adapters_disabled
        mock_layer.merged = merged

        mock_layer.active_adapters = ["default"]
        mock_layer.lora_A = {"default": Mock()}
        mock_layer.lora_B = {"default": Mock()}
        mock_layer.scaling = {"default": 0.1}

        mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
        mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
        return mock_layer

    def test_parameter_freezing_adapters_disabled(self):
        """Test that LoRA parameters are None when adapters are disabled."""
        layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)

        W, b, _, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None
        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_freezing_adapters_merged(self):
        """Test that LoRA parameters are None when adapters are merged."""
        layer = self.create_mock_lora_layer(has_adapters=True, merged=True)

        W, b, _, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None

        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_freezing_no_adapters(self):
        """Test parameter behavior when no adapters are present."""
        layer = self.create_mock_lora_layer(has_adapters=False)

        W, b, _, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None

        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_active_adapters_enabled(self):
        """Test that LoRA parameters are returned when adapters are active."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, _, A, B, s = get_lora_parameters(layer)

        # All parameters should be returned
        assert W is not None
        assert b is not None
        assert A is not None
        assert B is not None
        assert s is not None
        assert s == 0.1

    def test_parameter_shapes_consistency(self):
        """Test that parameter shapes are consistent when active."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, _, A, B, _ = get_lora_parameters(layer)

        # Check shape consistency
        assert W.shape == (512, 256)
        assert b.shape == (512,)
        assert A.shape == (16, 256)
        assert B.shape == (512, 16)

    def test_parameter_dtypes_consistency(self):
        """Test that parameter dtypes are consistent."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, _, A, B, _ = get_lora_parameters(layer)

        assert W.dtype == self.dtype
        assert b.dtype == self.dtype
        assert A.dtype == self.dtype
        assert B.dtype == self.dtype

    def test_quantization_state_handling(self):
        """Test that quantization state is properly handled."""
        layer = self.create_mock_lora_layer(has_adapters=True)

        # The quant state is attached to the base weight tensor itself.
        quant_state_mock = Mock()
        layer.base_layer.weight.quant_state = quant_state_mock

        _, _, quant_state, _, _, _ = get_lora_parameters(layer)

        assert quant_state == quant_state_mock

    def test_multiple_adapters_active_adapter_selection(self):
        """Test that the correct adapter is selected when multiple adapters exist."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        # Register a second adapter next to "default" and make it the active one.
        layer.lora_A["adapter2"] = Mock()
        layer.lora_B["adapter2"] = Mock()
        layer.scaling["adapter2"] = 0.2

        layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
        layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)

        layer.active_adapters = ["adapter2"]

        _, _, _, A, B, s = get_lora_parameters(layer)

        assert s == 0.2
        assert torch.equal(A, layer.lora_A["adapter2"].weight)
        assert torch.equal(B, layer.lora_B["adapter2"].weight)


class TestLoRAParameterFreezingIntegration:
    """Integration tests for parameter freezing with actual LoRA layers."""

    @_requires_peft
    def test_parameter_freezing_with_real_lora_layer(self):
        """Test parameter freezing with actual PEFT LoRA layer."""
        model = _build_peft_linear_model()
        lora_layer = model.base_model.model.linear

        # Test with adapters enabled
        _, _, _, A, B, s = get_lora_parameters(lora_layer)
        assert A is not None
        assert B is not None
        assert s is not None

        # Test with adapters disabled
        model.disable_adapter_layers()
        _, _, _, A, B, s = get_lora_parameters(lora_layer)
        assert A is None
        assert B is None
        assert s is None

    @_requires_peft
    def test_parameter_freezing_gradient_behavior(self):
        """Test that frozen parameters don't receive gradients."""
        model = _build_peft_linear_model()
        lora_layer = model.base_model.model.linear
        loss_fn = nn.MSELoss()

        x = torch.randn(1, 256)
        target = torch.randn(1, 512)

        model.enable_adapter_layers()
        output = model(x)
        loss = loss_fn(output, target)
        loss.backward()
        assert _has_lora_grads(lora_layer), (
            "LoRA parameters should have gradients when adapters are enabled"
        )

        # The assertion below checks ``grad is None``, so gradients must be
        # dropped rather than zero-filled; request that explicitly instead of
        # relying on the installed torch version's default.
        model.zero_grad(set_to_none=True)
        model.disable_adapter_layers()
        output = model(x)
        loss = loss_fn(output, target)
        # backward() raises on a loss that is detached from every trainable
        # parameter, so guard on the loss itself rather than on whether some
        # parameter somewhere in the model still requires grad.
        if loss.requires_grad:
            loss.backward()
        assert not _has_lora_grads(lora_layer), (
            "LoRA parameters should not have gradients when adapters are disabled"
        )

        model.zero_grad(set_to_none=True)
        del model, lora_layer, x, target, output, loss
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# File: tests/utils/lora/test_merge_lora.py
"""Tests for LoRA merge arithmetic and the ``do_merge_lora`` CLI entry point."""

from unittest.mock import Mock, patch

import torch

from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault


class TestAdapterMergeUnmerge:
    """Test suite for LoRA adapter merging/unmerging functionality"""

    def setup_method(self):
        """Fix the dtype and device shared by the mock models."""
        self.dtype = torch.float32
        self.device = torch.device("cpu")

    def create_mock_base_model(self, vocab_size=1000, hidden_size=256):
        """Create a mock base model with linear layers

        The mock exposes ``config.vocab_size`` / ``config.hidden_size`` and two
        projections, ``q_proj`` and ``v_proj``, each with a random square
        ``weight`` of side ``hidden_size`` and a random ``bias``.
        """
        mock_model = Mock()

        mock_model.config = Mock()
        mock_model.config.vocab_size = vocab_size
        mock_model.config.hidden_size = hidden_size

        for proj_name in ("q_proj", "v_proj"):
            proj = Mock()
            proj.weight = torch.randn(hidden_size, hidden_size, dtype=self.dtype)
            proj.bias = torch.randn(hidden_size, dtype=self.dtype)
            setattr(mock_model, proj_name, proj)

        return mock_model

    def create_mock_lora_model(self, base_model, r=8, alpha=16):
        """Create a mock LoRA model wrapping the base model

        Side effects: stores on ``self`` the pre-merge weights
        (``original_q_weight`` / ``original_v_weight``), the random LoRA
        factors (``lora_A_q``, ``lora_B_q``, ``lora_A_v``, ``lora_B_v``) and
        ``scaling = alpha / r`` so tests can recompute the expected delta.

        The returned mock's ``merge_and_unload`` writes
        ``W_base + (B @ A) * scaling`` into ``base_model`` and returns
        ``base_model``.
        """
        mock_lora_model = Mock()
        mock_lora_model.base_model = base_model
        mock_lora_model.to = Mock(return_value=mock_lora_model)

        mock_lora_model.generation_config = Mock()
        mock_lora_model.config = Mock()

        self.original_q_weight = base_model.q_proj.weight.clone()
        self.original_v_weight = base_model.v_proj.weight.clone()

        mock_lora_model.peft_config = {"default": Mock()}
        mock_lora_model.peft_config["default"].r = r
        mock_lora_model.peft_config["default"].lora_alpha = alpha

        q_out, q_in = base_model.q_proj.weight.shape
        self.lora_A_q = torch.randn(r, q_in, dtype=self.dtype)
        self.lora_B_q = torch.randn(q_out, r, dtype=self.dtype)

        v_out, v_in = base_model.v_proj.weight.shape
        self.lora_A_v = torch.randn(r, v_in, dtype=self.dtype)
        self.lora_B_v = torch.randn(v_out, r, dtype=self.dtype)

        self.scaling = alpha / r

        def mock_merge_and_unload(progressbar=False):
            """Simulate the actual merge operation"""
            # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
            delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
            delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling

            base_model.q_proj.weight = self.original_q_weight + delta_q
            base_model.v_proj.weight = self.original_v_weight + delta_v

            return base_model

        mock_lora_model.merge_and_unload = mock_merge_and_unload
        return mock_lora_model

    def test_basic_lora_merge_unmerge_cycle(self):
        """Test: original_weights -> merge -> unmerge -> should equal original_weights"""
        base_model = self.create_mock_base_model()
        lora_model = self.create_mock_lora_model(base_model)

        original_q_weight = self.original_q_weight.clone()
        original_v_weight = self.original_v_weight.clone()

        merged_model = lora_model.merge_and_unload()

        assert not torch.equal(merged_model.q_proj.weight, original_q_weight)
        assert not torch.equal(merged_model.v_proj.weight, original_v_weight)

        # Unmerging is simulated by subtracting the same delta again.
        delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
        delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling

        unmerged_q_weight = merged_model.q_proj.weight - delta_q
        unmerged_v_weight = merged_model.v_proj.weight - delta_v

        assert torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)
        assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)

    def test_merge_weight_calculation_accuracy(self):
        """Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)"""
        base_model = self.create_mock_base_model()
        lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)

        expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
        expected_merged_q = self.original_q_weight + expected_delta_q
        merged_model = lora_model.merge_and_unload()

        assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)

    @patch("axolotl.cli.merge_lora.load_model_and_tokenizer")
    def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):
        """With the loader patched, ``do_merge_lora`` saves model and tokenizer once each."""
        base_model = self.create_mock_base_model()
        lora_model = self.create_mock_lora_model(base_model)
        tokenizer = Mock()
        processor = None

        mock_load_model.return_value = (lora_model, tokenizer, processor)

        cfg = DictDefault(
            {
                "save_safetensors": True,
                "torch_dtype": torch.float32,
                "local_rank": 0,
                "output_dir": str(tmp_path),
            }
        )

        # ``base_model`` is what the mocked merge returns, so its
        # ``save_pretrained`` is the one expected to be called.
        with (
            patch("pathlib.Path.mkdir"),
            patch.object(base_model, "save_pretrained") as mock_save_model,
            patch.object(tokenizer, "save_pretrained") as mock_save_tokenizer,
        ):
            do_merge_lora(cfg=cfg)

        mock_save_model.assert_called_once()
        mock_save_tokenizer.assert_called_once()

    def test_quantized_model_merge_compatibility(self):
        """Smoke test: the mocked merge still returns a model when a weight carries a ``quant_state``.

        Only the mock merge defined in this class is exercised; no real 4-bit
        or 8-bit merge path is run.
        """
        base_model = self.create_mock_base_model()

        # Mock quantized weights
        base_model.q_proj.weight.quant_state = Mock()
        base_model.q_proj.weight.quant_state.dtype = torch.uint8

        lora_model = self.create_mock_lora_model(base_model)

        merged_model = lora_model.merge_and_unload()
        assert merged_model is not None

    @patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": ""})
    def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):
        """Run ``do_merge_lora`` with ``lora_on_cpu=True`` and check the loader was invoked."""
        cfg = DictDefault(
            {
                "lora_on_cpu": True,
                "save_safetensors": True,
                "output_dir": str(tmp_path),
                "local_rank": 0,
            }
        )

        with (
            patch("axolotl.cli.merge_lora.load_model_and_tokenizer") as mock_load,
            patch("pathlib.Path.mkdir"),
            patch("torch.save"),
        ):
            base_model = self.create_mock_base_model()
            lora_model = self.create_mock_lora_model(base_model)
            mock_load.return_value = (lora_model, Mock(), None)

            do_merge_lora(cfg=cfg)

        assert mock_load.called