import importlib.util
from unittest.mock import Mock

import pytest
import torch
import torch.nn as nn

from axolotl.kernels.lora import get_lora_parameters

PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None
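# Interface assumed by these tests (inferred from the assertions below):
#   get_lora_parameters(layer) -> (W, b, quant_state, A, B, s)
# where A, B, and s are expected to be None whenever adapters are disabled,
# merged into the base weights, or absent.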
class TestLoRAParameterFreezing:
    """Test suite for LoRA parameter freezing validation."""

    def setup_method(self):
        self.dtype = torch.float32

    def create_mock_lora_layer(
        self, has_adapters=True, adapters_disabled=False, merged=False
    ):
        """Create a mock LoRA layer for testing."""
        mock_layer = Mock()

        base_layer = Mock()
        base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
        base_layer.bias = torch.randn(512, dtype=self.dtype)

        if has_adapters:
            mock_layer.base_layer = base_layer
            mock_layer.disable_adapters = adapters_disabled
            mock_layer.merged = merged

            mock_layer.active_adapters = ["default"]
            mock_layer.lora_A = {"default": Mock()}
            mock_layer.lora_B = {"default": Mock()}
            mock_layer.scaling = {"default": 0.1}

            mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
            mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
        else:
            mock_layer.weight = base_layer.weight
            mock_layer.bias = base_layer.bias

        return mock_layer
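    # Note: the mock above mirrors the attribute names of PEFT's LoraLayer
    # (base_layer, disable_adapters, merged, active_adapters, lora_A, lora_B,
    # scaling), i.e. the shape of object get_lora_parameters is expected to
    # receive.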
    def test_parameter_freezing_adapters_disabled(self):
        """Test that LoRA parameters are None when adapters are disabled."""
        layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None
        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_freezing_adapters_merged(self):
        """Test that LoRA parameters are None when adapters are merged."""
        layer = self.create_mock_lora_layer(has_adapters=True, merged=True)

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None

        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_freezing_no_adapters(self):
        """Test parameter behavior when no adapters are present."""
        layer = self.create_mock_lora_layer(has_adapters=False)

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Base parameters should be returned
        assert W is not None
        assert b is not None

        # LoRA parameters should be None (frozen)
        assert A is None
        assert B is None
        assert s is None

    def test_parameter_active_adapters_enabled(self):
        """Test that LoRA parameters are returned when adapters are active."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # All parameters should be returned
        assert W is not None
        assert b is not None
        assert A is not None
        assert B is not None
        assert s is not None
        assert s == 0.1

    def test_parameter_shapes_consistency(self):
        """Test that parameter shapes are consistent when active."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        # Check shape consistency
        assert W.shape == (512, 256)
        assert b.shape == (512,)
        assert A.shape == (16, 256)
        assert B.shape == (512, 16)

    def test_parameter_dtypes_consistency(self):
        """Test that parameter dtypes are consistent."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        assert W.dtype == self.dtype
        assert b.dtype == self.dtype
        assert A.dtype == self.dtype
        assert B.dtype == self.dtype
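    # In real models the quant_state attribute comes from a quantized weight
    # (e.g. a bitsandbytes 4-bit parameter); here a Mock stands in for it.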
    def test_quantization_state_handling(self):
        """Test that quantization state is properly handled."""
        layer = self.create_mock_lora_layer(has_adapters=True)

        quant_state_mock = Mock()
        layer.base_layer.weight.quant_state = quant_state_mock

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        assert quant_state == quant_state_mock

    def test_multiple_adapters_active_adapter_selection(self):
        """Test that the correct adapter is selected when multiple adapters exist."""
        layer = self.create_mock_lora_layer(
            has_adapters=True, adapters_disabled=False, merged=False
        )

        layer.lora_A["adapter2"] = Mock()
        layer.lora_B["adapter2"] = Mock()
        layer.scaling["adapter2"] = 0.2

        layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
        layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)

        layer.active_adapters = ["adapter2"]

        W, b, quant_state, A, B, s = get_lora_parameters(layer)

        assert s == 0.2
        assert torch.equal(A, layer.lora_A["adapter2"].weight)
        assert torch.equal(B, layer.lora_B["adapter2"].weight)
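# The unit tests above run without PEFT installed; the integration tests below
# are skipped unless PEFT is available. One way to run just this module's tests:
#   pytest -k "LoRAParameterFreezing" -v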
class TestLoRAParameterFreezingIntegration:
    """Integration tests for parameter freezing with actual LoRA layers."""

    @pytest.mark.skipif(
        not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
    )
    def test_parameter_freezing_with_real_lora_layer(self):
        """Test parameter freezing with actual PEFT LoRA layer."""
        from peft import LoraConfig, get_peft_model

        class SimpleModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(256, 512)

            def forward(self, x):
                return self.linear(x)

        base_model = SimpleModel()
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["linear"],
            lora_dropout=0.1,
        )
        model = get_peft_model(base_model, lora_config)
        lora_layer = model.base_model.model.linear
        # Test with adapters enabled
        W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
        assert A is not None
        assert B is not None
        assert s is not None
        # Test with adapters disabled
        model.disable_adapter_layers()
        W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
        assert A is None
        assert B is None
        assert s is None
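    # Note: with r=16 and lora_alpha=32 above, PEFT's default adapter scaling
    # is lora_alpha / r == 2.0, so the `s` asserted non-None in the previous
    # test is expected to equal 2.0 (assuming use_rslora is left at its default).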
    @pytest.mark.skipif(
        not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
    )
    def test_parameter_freezing_gradient_behavior(self):
        """Test that frozen parameters don't receive gradients."""
        from peft import LoraConfig, get_peft_model

        class SimpleModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(256, 512)

            def forward(self, x):
                return self.linear(x)

        base_model = SimpleModel()
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["linear"],
            lora_dropout=0.1,
        )
        model = get_peft_model(base_model, lora_config)
        x = torch.randn(1, 256)
        target = torch.randn(1, 512)
        model.enable_adapter_layers()
        output = model(x)
        loss = nn.MSELoss()(output, target)
        loss.backward()
        lora_layer = model.base_model.model.linear
        has_lora_grads = any(
            param.grad is not None
            for name, param in lora_layer.named_parameters()
            if "lora_" in name
        )
        assert has_lora_grads, (
            "LoRA parameters should have gradients when adapters are enabled"
        )
        model.zero_grad()
        model.disable_adapter_layers()
        output = model(x)
        loss = nn.MSELoss()(output, target)
        any_requires_grad = any(param.requires_grad for param in model.parameters())
        if any_requires_grad:
            loss.backward()
        has_lora_grads_disabled = any(
            param.grad is not None
            for name, param in lora_layer.named_parameters()
            if "lora_" in name
        )
        assert not has_lora_grads_disabled, (
            "LoRA parameters should not have gradients when adapters are disabled"
        )
        model.zero_grad()
        del model, base_model, lora_layer, x, target, output, loss
        if torch.cuda.is_available():
            torch.cuda.empty_cache()