feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]

* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
This commit is contained in:
Wing Lian
2026-03-22 13:53:19 -04:00
committed by GitHub
parent a67392c427
commit b3289fd190
13 changed files with 2847 additions and 448 deletions

View File

@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
proj.base_layer = base_layer
W, b, quant_state, A, B, s = get_lora_parameters(proj)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
# quant_state should be None since weight is bf16, not FP8
self.assertIsNone(quant_state)
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
scale_inv = torch.ones(1)
base_layer.weight_scale_inv = scale_inv
W, b, quant_state, A, B, s = get_lora_parameters(proj)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
self.assertIs(quant_state, scale_inv)

View File

@@ -102,7 +102,7 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,120 @@
"""Test LoRA kernels under FSDP2 multi-GPU training.
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
lora_embedding_kernel work correctly with FSDP2 sharding, including
with bias, dropout, and DoRA enabled.
"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def _run_training(temp_dir, cfg):
"""Write config and launch multi-GPU training."""
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
def _base_lora_fsdp2_config(temp_dir, **overrides):
"""Base config for LoRA + FSDP2 + kernel tests."""
cfg = {
"base_model": "Qwen/Qwen3-0.6B",
"sequence_len": 512,
"val_set_size": 0.0,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:1%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
# Enable all LoRA kernels
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_embedding_kernel": True,
"save_safetensors": True,
}
cfg.update(overrides)
return DictDefault(cfg)
class TestFSDP2LoRAKernels:
"""Test LoRA kernels under FSDP2."""
@require_torch_2_7_0
def test_lora_kernels_basic(self, temp_dir):
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
cfg = _base_lora_fsdp2_config(temp_dir)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dropout(self, temp_dir):
"""LoRA kernels + dropout + FSDP2."""
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dora(self, temp_dir):
"""LoRA kernels + DoRA + FSDP2."""
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
"""LoRA kernels + DoRA + dropout + FSDP2."""
cfg = _base_lora_fsdp2_config(
temp_dir,
peft_use_dora=True,
lora_dropout=0.05,
)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
test_configs = [
# Dropout prevents patching
# Dropout — kernels now support this
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
# Bias — kernels now support this
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
# Verify patches ARE applied (dropout and bias are now supported)
assert (
layer.forward.__func__ is apply_lora_mlp_swiglu
or layer.forward.__func__ is apply_lora_mlp_geglu
)
def test_kernel_config_options():
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""
"""Test model loading with dropout non-zero DOES patch (now supported)."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)
# Store original state before patching
original_forward_method = attention_cls.forward
# Load model
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
# We call modelloader as that's where the patches are applied
# despite the fact that we're not using it to load the model
model_loader = ModelLoader(cfg, tokenizer)
# Apply patch
# Apply patches — should succeed even with dropout > 0
model_loader.patch_manager._apply_self_attention_lora_patch()
# Verify patch was not applied
assert attention_cls.forward == original_forward_method
# Apply apply_lora_kernel_patches
model_loader.patch_manager._apply_lora_kernel_patch(model)
# Verify patch was not applied
# Verify patches WERE applied (dropout is now supported by kernels)
layers = get_layers(model)
for layer in layers:
for self_attn in find_self_attn_in_layer(layer):
assert not hasattr(self_attn, "apply_qkv")
assert not hasattr(self_attn, "apply_o")
assert hasattr(self_attn, "apply_qkv")
assert hasattr(self_attn, "apply_o")

View File

@@ -28,20 +28,22 @@ class TestLoRAConfigValidation:
result = validate_config(valid_config)
assert result["adapter"] == "lora"
with pytest.raises(ValueError, match="not compatible with DoRA"):
invalid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
# DoRA is now compatible with lora kernels
dora_kernel_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(dora_kernel_config)
assert result["lora_mlp_kernel"] is True
assert result["peft_use_dora"] is True
def test_qlora_4bit_validation(self):
"""Test QLoRA 4-bit configuration validation"""

View File

@@ -38,6 +38,11 @@ class TestLoRAParameterFreezing:
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
mock_layer.lora_B["default"].bias = None
# Required by get_lora_parameters for dropout/DoRA extraction
mock_layer.lora_dropout = {}
mock_layer.lora_magnitude_vector = None
else:
mock_layer.weight = base_layer.weight
mock_layer.bias = base_layer.bias
@@ -48,7 +53,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are disabled."""
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -62,7 +67,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are merged."""
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -77,7 +82,7 @@ class TestLoRAParameterFreezing:
"""Test parameter behavior when no adapters are present."""
layer = self.create_mock_lora_layer(has_adapters=False)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -94,7 +99,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# All parameters should be returned
assert W is not None
@@ -110,7 +115,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Check shape consistency
assert W.shape == (512, 256)
@@ -124,7 +129,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert W.dtype == self.dtype
assert b.dtype == self.dtype
@@ -138,7 +143,7 @@ class TestLoRAParameterFreezing:
quant_state_mock = Mock()
layer.base_layer.weight.quant_state = quant_state_mock
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert quant_state == quant_state_mock
@@ -157,7 +162,7 @@ class TestLoRAParameterFreezing:
layer.active_adapters = ["adapter2"]
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert s == 0.2
assert torch.equal(A, layer.lora_A["adapter2"].weight)
@@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration:
model = get_peft_model(base_model, lora_config)
lora_layer = model.base_model.model.linear
# Test with adapters enabled
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
assert A is not None
assert B is not None
assert s is not None
# Test with adapters disabled
model.disable_adapter_layers()
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
assert A is None
assert B is None
assert s is None