Increased test coverage for lora/qlora (#3147)
* config_val tests * remove config val(not needed) * config validation * parameter freeze validation * merge/unmerge tests * removal unwanted * rename * lint * updated lint * Update tests/utils/lora/test_config_validation_lora.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * pytest skip + mock fix * nitpicks * revert some nitpicks --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
261
tests/utils/lora/test_freeze_lora.py
Normal file
261
tests/utils/lora/test_freeze_lora.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import importlib.util
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from axolotl.kernels.lora import get_lora_parameters
|
||||
|
||||
PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None
|
||||
|
||||
|
||||
class TestLoRAParameterFreezing:
|
||||
"""Test suite for LoRA parameter freezing validation."""
|
||||
|
||||
def setup_method(self):
|
||||
self.dtype = torch.float32
|
||||
|
||||
def create_mock_lora_layer(
|
||||
self, has_adapters=True, adapters_disabled=False, merged=False
|
||||
):
|
||||
"""Create a mock LoRA layer for testing."""
|
||||
mock_layer = Mock()
|
||||
|
||||
base_layer = Mock()
|
||||
base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
|
||||
base_layer.bias = torch.randn(512, dtype=self.dtype)
|
||||
|
||||
if has_adapters:
|
||||
mock_layer.base_layer = base_layer
|
||||
mock_layer.disable_adapters = adapters_disabled
|
||||
mock_layer.merged = merged
|
||||
|
||||
mock_layer.active_adapters = ["default"]
|
||||
mock_layer.lora_A = {"default": Mock()}
|
||||
mock_layer.lora_B = {"default": Mock()}
|
||||
mock_layer.scaling = {"default": 0.1}
|
||||
|
||||
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||
else:
|
||||
mock_layer.weight = base_layer.weight
|
||||
mock_layer.bias = base_layer.bias
|
||||
|
||||
return mock_layer
|
||||
|
||||
def test_parameter_freezing_adapters_disabled(self):
|
||||
"""Test that LoRA parameters are None when adapters are disabled."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
# LoRA parameters should be None (frozen)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
def test_parameter_freezing_adapters_merged(self):
|
||||
"""Test that LoRA parameters are None when adapters are merged."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
|
||||
# LoRA parameters should be None (frozen)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
def test_parameter_freezing_no_adapters(self):
|
||||
"""Test parameter behavior when no adapters are present."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=False)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
|
||||
# LoRA parameters should be None (frozen)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
def test_parameter_active_adapters_enabled(self):
|
||||
"""Test that LoRA parameters are returned when adapters are active."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# All parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
assert A is not None
|
||||
assert B is not None
|
||||
assert s is not None
|
||||
assert s == 0.1
|
||||
|
||||
def test_parameter_shapes_consistency(self):
|
||||
"""Test that parameter shapes are consistent when active."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Check shape consistency
|
||||
assert W.shape == (512, 256)
|
||||
assert b.shape == (512,)
|
||||
assert A.shape == (16, 256)
|
||||
assert B.shape == (512, 16)
|
||||
|
||||
def test_parameter_dtypes_consistency(self):
|
||||
"""Test that parameter dtypes are consistent."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
assert W.dtype == self.dtype
|
||||
assert b.dtype == self.dtype
|
||||
assert A.dtype == self.dtype
|
||||
assert B.dtype == self.dtype
|
||||
|
||||
def test_quantization_state_handling(self):
|
||||
"""Test that quantization state is properly handled."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True)
|
||||
|
||||
quant_state_mock = Mock()
|
||||
layer.base_layer.weight.quant_state = quant_state_mock
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
assert quant_state == quant_state_mock
|
||||
|
||||
def test_multiple_adapters_active_adapter_selection(self):
|
||||
"""Test that the correct adapter is selected when multiple adapters exist."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
layer.lora_A["adapter2"] = Mock()
|
||||
layer.lora_B["adapter2"] = Mock()
|
||||
layer.scaling["adapter2"] = 0.2
|
||||
|
||||
layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||
layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||
|
||||
layer.active_adapters = ["adapter2"]
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
assert s == 0.2
|
||||
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
||||
assert torch.equal(B, layer.lora_B["adapter2"].weight)
|
||||
|
||||
|
||||
class TestLoRAParameterFreezingIntegration:
|
||||
"""Integration tests for parameter freezing with actual LoRA layers."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
|
||||
)
|
||||
def test_parameter_freezing_with_real_lora_layer(self):
|
||||
"""Test parameter freezing with actual PEFT LoRA layer."""
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
base_model = SimpleModel()
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["linear"],
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
lora_layer = model.base_model.model.linear
|
||||
# Test with adapters enabled
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||
assert A is not None
|
||||
assert B is not None
|
||||
assert s is not None
|
||||
# Test with adapters disabled
|
||||
model.disable_adapter_layers()
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
|
||||
)
|
||||
def test_parameter_freezing_gradient_behavior(self):
|
||||
"""Test that frozen parameters don't receive gradients."""
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
base_model = SimpleModel()
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["linear"],
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
x = torch.randn(1, 256)
|
||||
target = torch.randn(1, 512)
|
||||
model.enable_adapter_layers()
|
||||
output = model(x)
|
||||
loss = nn.MSELoss()(output, target)
|
||||
loss.backward()
|
||||
lora_layer = model.base_model.model.linear
|
||||
has_lora_grads = any(
|
||||
param.grad is not None
|
||||
for name, param in lora_layer.named_parameters()
|
||||
if "lora_" in name
|
||||
)
|
||||
assert has_lora_grads, (
|
||||
"LoRA parameters should have gradients when adapters are enabled"
|
||||
)
|
||||
model.zero_grad()
|
||||
model.disable_adapter_layers()
|
||||
output = model(x)
|
||||
loss = nn.MSELoss()(output, target)
|
||||
any_requires_grad = any(param.requires_grad for param in model.parameters())
|
||||
if any_requires_grad:
|
||||
loss.backward()
|
||||
has_lora_grads_disabled = any(
|
||||
param.grad is not None
|
||||
for name, param in lora_layer.named_parameters()
|
||||
if "lora_" in name
|
||||
)
|
||||
assert not has_lora_grads_disabled, (
|
||||
"LoRA parameters should not have gradients when adapters are disabled"
|
||||
)
|
||||
model.zero_grad()
|
||||
del model, base_model, lora_layer, x, target, output, loss
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
Reference in New Issue
Block a user