"""
Tests for LoRA kernel correctness with bias, dropout, and DoRA support.
Compares fused kernel outputs and gradients against PEFT's reference implementation.
"""
import pytest
import torch
from peft import LoraConfig, get_peft_model
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from axolotl.kernels.lora import (
_compute_dora_scale,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
get_lora_parameters,
matmul_lora,
)
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
MODEL_NAME = "Qwen/Qwen3-0.6B"
DEVICE = "cuda"
DTYPE = torch.bfloat16
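
# Reference LoRA math exercised by the kernels under test (a sketch; the unit tests
# below recompute these expressions by hand in TestMatmulLora and TestDoraScale):
#   LoRA forward: Y = X @ W.T + s * dropout(X) @ A.T @ B.T  (+ s * lora_bias)
#   DoRA scale:   magnitude / ||W + s * B @ A||  (row-wise L2 norm over dim=1)
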
@pytest.fixture(scope="module")
def model_config():
return AutoConfig.from_pretrained(MODEL_NAME)
def _make_peft_model(
lora_dropout=0.0,
bias="none",
use_dora=False,
target_modules=None,
):
"""Create a PEFT model with given config."""
if target_modules is None:
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=DTYPE,
attn_implementation="eager",
).to(DEVICE)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=lora_dropout,
bias=bias,
use_dora=use_dora,
target_modules=target_modules,
)
peft_model = get_peft_model(model, lora_config)
return peft_model
def _get_layer(peft_model, layer_idx=0):
"""Get a specific transformer layer from the model."""
return peft_model.model.model.layers[layer_idx]
def _make_input(batch=2, seq_len=16, hidden_size=1024):
"""Create random input tensor."""
return torch.randn(
batch, seq_len, hidden_size, dtype=DTYPE, device=DEVICE, requires_grad=True
)
def _compare_tensors(a, b, name="", atol=1e-2, rtol=1e-2):
"""Compare two tensors with informative error messages."""
if a is None and b is None:
return
assert a is not None and b is not None, f"{name}: one is None, other is not"
assert a.shape == b.shape, f"{name}: shape mismatch {a.shape} vs {b.shape}"
diff = (a - b).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
assert torch.allclose(a, b, atol=atol, rtol=rtol), (
f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
)
class TestGetLoraParameters:
"""Test the extended get_lora_parameters function."""
def test_returns_9_values(self):
model = _make_peft_model()
layer = _get_layer(model)
params = get_lora_parameters(layer.self_attn.q_proj)
assert len(params) == 9
W, b, quant, A, B, s, lora_bias, dropout, magnitude = params
assert W is not None
assert A is not None
assert B is not None
assert s is not None
assert lora_bias is None # bias="none"
assert dropout is not None # should be nn.Identity
assert magnitude is None # no DoRA
del model
def test_with_bias(self):
"""Qwen3 has no base bias, so PEFT doesn't add lora_bias even with bias='lora_only'.
This test verifies get_lora_parameters handles this correctly."""
model = _make_peft_model(bias="lora_only")
layer = _get_layer(model)
params = get_lora_parameters(layer.self_attn.q_proj)
_, _, _, _, _, _, lora_bias, _, _ = params
# Qwen3 q_proj has no base bias, so PEFT sets lora_bias=False
assert lora_bias is None
del model
def test_with_bias_on_biased_layer(self):
"""Test with manually added bias to verify lora_bias extraction."""
model = _make_peft_model(bias="lora_only")
layer = _get_layer(model)
q_proj = layer.self_attn.q_proj
adapter = q_proj.active_adapters[0]
# Manually add bias to lora_B to test extraction
old_B = q_proj.lora_B[adapter]
q_proj.lora_B[adapter] = torch.nn.Linear(
old_B.in_features, old_B.out_features, bias=True, device=DEVICE, dtype=DTYPE
)
params = get_lora_parameters(q_proj)
_, _, _, _, _, _, lora_bias, _, _ = params
assert lora_bias is not None
assert lora_bias.shape[0] == old_B.out_features
del model
def test_with_dropout(self):
model = _make_peft_model(lora_dropout=0.1)
layer = _get_layer(model)
params = get_lora_parameters(layer.self_attn.q_proj)
_, _, _, _, _, _, _, dropout, _ = params
assert dropout is not None
assert isinstance(dropout, nn.Dropout)
del model
def test_with_dora(self):
model = _make_peft_model(use_dora=True)
layer = _get_layer(model)
params = get_lora_parameters(layer.self_attn.q_proj)
_, _, _, _, _, _, _, _, magnitude = params
assert magnitude is not None
del model
class TestMatmulLora:
"""Test matmul_lora with new lora_bias and X_drop parameters."""
def test_basic(self):
X = torch.randn(4, 8, dtype=DTYPE, device=DEVICE)
W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE)
A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) # [rank, in]
B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE) # [out, rank]
s = 2.0
result = matmul_lora(X, W, None, None, A, B, s)
expected = X @ W.t() + s * X @ A.t() @ B.t()
_compare_tensors(result, expected, "basic matmul_lora")
def test_with_lora_bias(self):
X = torch.randn(4, 8, dtype=DTYPE, device=DEVICE)
W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE)
A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE)
B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE)
lora_bias = torch.randn(16, dtype=DTYPE, device=DEVICE)
s = 2.0
result = matmul_lora(X, W, None, None, A, B, s, lora_bias=lora_bias)
expected = X @ W.t() + s * X @ A.t() @ B.t() + s * lora_bias
_compare_tensors(result, expected, "matmul_lora with lora_bias")
def test_with_x_drop(self):
X = torch.randn(4, 8, dtype=DTYPE, device=DEVICE)
X_drop = X * 0.5 # simulated dropout
W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE)
A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE)
B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE)
s = 2.0
result = matmul_lora(X, W, None, None, A, B, s, X_drop=X_drop)
expected = X @ W.t() + s * X_drop @ A.t() @ B.t()
_compare_tensors(result, expected, "matmul_lora with X_drop")
class TestDoraScale:
"""Test DoRA magnitude/norm scaling computation."""
def test_basic(self):
W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE)
A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE)
B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE)
magnitude = torch.randn(16, dtype=DTYPE, device=DEVICE).abs() + 0.1
s = 2.0
scale = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE)
# Manual computation
combined = W + s * B @ A
weight_norm = torch.linalg.norm(combined, dim=1)
expected = magnitude / weight_norm
_compare_tensors(scale, expected, "dora_scale")
# ============================================================
# Integration tests: compare kernel outputs against PEFT reference
# ============================================================
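
# Pattern used by the classes below: run the same layer through PEFT's standard
# forward and through the fused kernel on identical inputs, then compare outputs
# (and, in the backward tests, the LoRA parameter gradients) within tolerances
# chosen for bf16.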
def _run_peft_qkv(layer, X):
"""Run Q, K, V projections through PEFT's standard forward."""
Q = layer.self_attn.q_proj(X)
K = layer.self_attn.k_proj(X)
V = layer.self_attn.v_proj(X)
return Q, K, V
def _run_kernel_qkv(layer, X):
"""Run Q, K, V projections through our fused kernel."""
return apply_lora_qkv(layer.self_attn, X, inplace=False)
def _run_peft_o(layer, X):
"""Run O projection through PEFT's standard forward."""
return layer.self_attn.o_proj(X)
def _run_kernel_o(layer, X):
"""Run O projection through our fused kernel."""
return apply_lora_o(layer.self_attn, X)
def _run_peft_mlp(layer, X):
"""Run MLP through PEFT's standard forward."""
return layer.mlp(X)
def _run_kernel_mlp(layer, X):
"""Run MLP through our fused kernel."""
return apply_lora_mlp_swiglu(layer.mlp, X, inplace=False)
class TestQKVKernel:
"""Test LoRA_QKV kernel against PEFT reference."""
@pytest.mark.parametrize("bias", ["none", "lora_only"])
def test_forward_bias(self, bias):
model = _make_peft_model(bias=bias)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X)
kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X)
_compare_tensors(kern_Q, peft_Q, f"QKV Q (bias={bias})")
_compare_tensors(kern_K, peft_K, f"QKV K (bias={bias})")
_compare_tensors(kern_V, peft_V, f"QKV V (bias={bias})")
del model
def test_forward_dropout_eval(self):
"""Dropout disabled in eval - should match exactly."""
model = _make_peft_model(lora_dropout=0.1)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X)
kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X)
_compare_tensors(kern_Q, peft_Q, "QKV Q (dropout eval)")
_compare_tensors(kern_K, peft_K, "QKV K (dropout eval)")
_compare_tensors(kern_V, peft_V, "QKV V (dropout eval)")
del model
def test_forward_dora(self):
model = _make_peft_model(use_dora=True)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X)
kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X)
_compare_tensors(kern_Q, peft_Q, "QKV Q (DoRA)")
_compare_tensors(kern_K, peft_K, "QKV K (DoRA)")
_compare_tensors(kern_V, peft_V, "QKV V (DoRA)")
del model
def test_forward_dora_bias(self):
model = _make_peft_model(use_dora=True, bias="lora_only")
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X)
kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X)
_compare_tensors(kern_Q, peft_Q, "QKV Q (DoRA+bias)")
_compare_tensors(kern_K, peft_K, "QKV K (DoRA+bias)")
_compare_tensors(kern_V, peft_V, "QKV V (DoRA+bias)")
del model
@pytest.mark.parametrize("bias", ["none", "lora_only"])
def test_backward_bias(self, bias):
"""Test that gradients match between kernel and PEFT."""
model = _make_peft_model(bias=bias)
model.train()
layer = _get_layer(model)
# PEFT reference
X1 = _make_input(hidden_size=model.config.hidden_size)
pQ, pK, pV = _run_peft_qkv(layer, X1)
loss_peft = pQ.sum() + pK.sum() + pV.sum()
loss_peft.backward()
peft_grads = {}
for name, param in layer.self_attn.named_parameters():
if param.grad is not None:
peft_grads[name] = param.grad.clone()
layer.self_attn.zero_grad()
# Kernel
X2 = X1.detach().clone().requires_grad_(True)
kQ, kK, kV = _run_kernel_qkv(layer, X2)
loss_kern = kQ.sum() + kK.sum() + kV.sum()
loss_kern.backward()
kern_grads = {}
for name, param in layer.self_attn.named_parameters():
if param.grad is not None:
kern_grads[name] = param.grad.clone()
layer.self_attn.zero_grad()
# Compare LoRA parameter gradients
for name in peft_grads:
if "lora_" in name:
_compare_tensors(
kern_grads.get(name),
peft_grads[name],
f"grad {name} (bias={bias})",
atol=5e-2,
rtol=5e-2,
)
del model
def test_backward_dora(self):
"""Test DoRA backward pass gradients."""
model = _make_peft_model(use_dora=True)
model.train()
layer = _get_layer(model)
X1 = _make_input(hidden_size=model.config.hidden_size)
pQ, pK, pV = _run_peft_qkv(layer, X1)
loss_peft = pQ.sum() + pK.sum() + pV.sum()
loss_peft.backward()
peft_grads = {}
for name, param in layer.self_attn.named_parameters():
if param.grad is not None:
peft_grads[name] = param.grad.clone()
layer.self_attn.zero_grad()
X2 = X1.detach().clone().requires_grad_(True)
kQ, kK, kV = _run_kernel_qkv(layer, X2)
loss_kern = kQ.sum() + kK.sum() + kV.sum()
loss_kern.backward()
kern_grads = {}
for name, param in layer.self_attn.named_parameters():
if param.grad is not None:
kern_grads[name] = param.grad.clone()
layer.self_attn.zero_grad()
for name in peft_grads:
if "lora_" in name or "magnitude" in name:
_compare_tensors(
kern_grads.get(name),
peft_grads[name],
f"grad {name} (DoRA)",
atol=5e-2,
rtol=5e-2,
)
del model
class TestOKernel:
"""Test LoRA_O kernel against PEFT reference."""
@staticmethod
def _o_input_dim(model):
"""o_proj input is num_heads * head_dim (may differ from hidden_size with GQA)."""
cfg = model.config
text_cfg = cfg.get_text_config() if hasattr(cfg, "get_text_config") else cfg
return text_cfg.num_attention_heads * text_cfg.head_dim
@pytest.mark.parametrize("bias", ["none", "lora_only"])
def test_forward_bias(self, bias):
model = _make_peft_model(bias=bias)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=self._o_input_dim(model))
with torch.no_grad():
peft_out = _run_peft_o(layer, X)
kern_out = _run_kernel_o(layer, X)
_compare_tensors(kern_out, peft_out, f"O (bias={bias})")
del model
def test_forward_dora(self):
model = _make_peft_model(use_dora=True)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=self._o_input_dim(model))
with torch.no_grad():
peft_out = _run_peft_o(layer, X)
kern_out = _run_kernel_o(layer, X)
_compare_tensors(kern_out, peft_out, "O (DoRA)")
del model
@pytest.mark.parametrize("bias", ["none", "lora_only"])
def test_backward_bias(self, bias):
model = _make_peft_model(bias=bias)
model.train()
layer = _get_layer(model)
X1 = _make_input(hidden_size=self._o_input_dim(model))
peft_out = _run_peft_o(layer, X1)
peft_out.sum().backward()
peft_grads = {
n: p.grad.clone()
for n, p in layer.self_attn.o_proj.named_parameters()
if p.grad is not None
}
layer.self_attn.o_proj.zero_grad()
X2 = X1.detach().clone().requires_grad_(True)
kern_out = _run_kernel_o(layer, X2)
kern_out.sum().backward()
kern_grads = {
n: p.grad.clone()
for n, p in layer.self_attn.o_proj.named_parameters()
if p.grad is not None
}
layer.self_attn.o_proj.zero_grad()
for name in peft_grads:
if "lora_" in name:
_compare_tensors(
kern_grads.get(name),
peft_grads[name],
f"O grad {name} (bias={bias})",
atol=5e-2,
rtol=5e-2,
)
del model
class TestMLPKernel:
"""Test LoRA_MLP kernel against PEFT reference."""
@pytest.mark.parametrize("bias", ["none", "lora_only"])
def test_forward_bias(self, bias):
model = _make_peft_model(bias=bias)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_out = _run_peft_mlp(layer, X)
kern_out = _run_kernel_mlp(layer, X)
_compare_tensors(kern_out, peft_out, f"MLP (bias={bias})")
del model
def test_forward_dropout_eval(self):
model = _make_peft_model(lora_dropout=0.1)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_out = _run_peft_mlp(layer, X)
kern_out = _run_kernel_mlp(layer, X)
_compare_tensors(kern_out, peft_out, "MLP (dropout eval)")
del model
def test_forward_dora(self):
model = _make_peft_model(use_dora=True)
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_out = _run_peft_mlp(layer, X)
kern_out = _run_kernel_mlp(layer, X)
        # Relaxed tolerance for MLP DoRA: the longer chain (3 projections +
        # activation + DoRA scaling) accumulates larger bf16 differences
_compare_tensors(kern_out, peft_out, "MLP (DoRA)", atol=0.3, rtol=0.05)
del model
def test_forward_dora_bias(self):
model = _make_peft_model(use_dora=True, bias="lora_only")
model.eval()
layer = _get_layer(model)
X = _make_input(hidden_size=model.config.hidden_size)
with torch.no_grad():
peft_out = _run_peft_mlp(layer, X)
kern_out = _run_kernel_mlp(layer, X)
_compare_tensors(kern_out, peft_out, "MLP (DoRA+bias)", atol=0.3, rtol=0.05)
del model
@pytest.mark.parametrize("bias", ["none", "lora_only"])
def test_backward_bias(self, bias):
model = _make_peft_model(bias=bias)
model.train()
layer = _get_layer(model)
hidden_size = model.config.hidden_size
X1 = _make_input(hidden_size=hidden_size)
peft_out = _run_peft_mlp(layer, X1)
peft_out.sum().backward()
peft_grads = {
n: p.grad.clone()
for n, p in layer.mlp.named_parameters()
if p.grad is not None
}
layer.mlp.zero_grad()
X2 = X1.detach().clone().requires_grad_(True)
kern_out = _run_kernel_mlp(layer, X2)
kern_out.sum().backward()
kern_grads = {
n: p.grad.clone()
for n, p in layer.mlp.named_parameters()
if p.grad is not None
}
layer.mlp.zero_grad()
        # MLP backward has a longer chain (3 projections + activation), so bf16
        # accumulation error is larger
for name in peft_grads:
if "lora_" in name:
_compare_tensors(
kern_grads.get(name),
peft_grads[name],
f"MLP grad {name} (bias={bias})",
atol=0.5,
rtol=0.1,
)
del model
def test_backward_dora(self):
model = _make_peft_model(use_dora=True)
model.train()
layer = _get_layer(model)
X1 = _make_input(hidden_size=model.config.hidden_size)
peft_out = _run_peft_mlp(layer, X1)
peft_out.sum().backward()
peft_grads = {
n: p.grad.clone()
for n, p in layer.mlp.named_parameters()
if p.grad is not None
}
layer.mlp.zero_grad()
X2 = X1.detach().clone().requires_grad_(True)
kern_out = _run_kernel_mlp(layer, X2)
kern_out.sum().backward()
kern_grads = {
n: p.grad.clone()
for n, p in layer.mlp.named_parameters()
if p.grad is not None
}
layer.mlp.zero_grad()
for name in peft_grads:
if "lora_" in name or "magnitude" in name:
_compare_tensors(
kern_grads.get(name),
peft_grads[name],
f"MLP grad {name} (DoRA)",
atol=0.5,
rtol=0.1,
)
del model
class TestFullModelPatch:
"""Test applying kernel patches to a full model."""
def test_patched_forward_basic(self):
"""Test that patched model forward matches unpatched PEFT model (bias=none, no DoRA)."""
from peft import PeftModelForCausalLM
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=DTYPE,
attn_implementation="eager",
).to(DEVICE)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
bias="none",
use_dora=False,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
model = PeftModelForCausalLM(base_model, lora_config)
model.eval()
# Get PEFT reference output
input_ids = torch.randint(0, 1000, (1, 32), device=DEVICE)
with torch.no_grad():
peft_out = model(input_ids).logits
# Apply kernel patches
cfg = DictDefault(
{
"base_model": MODEL_NAME,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_mlp_kernel": True,
}
)
patch_self_attn_lora(cfg)
apply_lora_kernel_patches(model, cfg)
# Get kernel output
with torch.no_grad():
kern_out = model(input_ids).logits
_compare_tensors(kern_out, peft_out, "Full model (basic)", atol=5e-1, rtol=1e-1)
del model
class TestEmbeddingKernel:
"""Test LoRA embedding kernel against PEFT reference."""
def _make_embedding_model(self, use_dora=False):
from peft import PeftModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=DTYPE,
attn_implementation="eager",
).to(DEVICE)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
use_dora=use_dora,
target_modules=["embed_tokens"],
)
return PeftModelForCausalLM(model, lora_config)
def test_forward_basic(self):
from axolotl.kernels.lora import apply_lora_embedding
model = self._make_embedding_model()
model.eval()
embed = model.model.model.embed_tokens
input_ids = torch.randint(0, 1000, (2, 16), device=DEVICE)
with torch.no_grad():
peft_out = embed(input_ids)
kern_out = apply_lora_embedding(embed, input_ids)
# Cast to same dtype for comparison (PEFT may return float32)
_compare_tensors(kern_out.to(peft_out.dtype), peft_out, "Embedding basic")
del model
def test_forward_dora(self):
from axolotl.kernels.lora import apply_lora_embedding
model = self._make_embedding_model(use_dora=True)
model.eval()
embed = model.model.model.embed_tokens
input_ids = torch.randint(0, 1000, (2, 16), device=DEVICE)
with torch.no_grad():
peft_out = embed(input_ids)
kern_out = apply_lora_embedding(embed, input_ids)
_compare_tensors(
kern_out.to(peft_out.dtype), peft_out, "Embedding DoRA", atol=0.3, rtol=0.05
)
del model
def test_backward(self):
from axolotl.kernels.lora import apply_lora_embedding
model = self._make_embedding_model()
model.train()
embed = model.model.model.embed_tokens
input_ids = torch.randint(0, 1000, (2, 16), device=DEVICE)
# PEFT reference
peft_out = embed(input_ids)
peft_out.sum().backward()
peft_grads = {}
for n, p in embed.named_parameters():
if p.grad is not None and "lora" in n:
peft_grads[n] = p.grad.clone()
embed.zero_grad()
# Kernel
kern_out = apply_lora_embedding(embed, input_ids)
kern_out.sum().backward()
kern_grads = {}
for n, p in embed.named_parameters():
if p.grad is not None and "lora" in n:
kern_grads[n] = p.grad.clone()
embed.zero_grad()
for name in peft_grads:
_compare_tensors(
kern_grads.get(name),
peft_grads[name],
f"Embedding grad {name}",
atol=5e-2,
rtol=5e-2,
)
del model
class TestTiedEmbeddings:
"""Test that tied embeddings work correctly with kernel patching."""
def test_tied_embed_and_lm_head(self):
"""When both embed_tokens and lm_head have LoRA, PEFT unties them.
Verify patched model produces valid output (no crashes, finite values)."""
from peft import PeftModelForCausalLM
base = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=DTYPE,
attn_implementation="eager",
).to(DEVICE)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=[
"embed_tokens",
"lm_head",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
model = PeftModelForCausalLM(base, lora_config)
model.eval()
cfg = DictDefault(
{
"base_model": MODEL_NAME,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_mlp_kernel": True,
"lora_embedding_kernel": True,
}
)
# Apply all kernel patches (class + instance level)
patch_self_attn_lora(cfg)
apply_lora_kernel_patches(model, cfg)
input_ids = torch.randint(0, 1000, (1, 32), device=DEVICE)
with torch.no_grad():
out = model(input_ids).logits
# Verify output is valid
assert out.shape == (1, 32, model.config.vocab_size)
assert torch.isfinite(out).all(), "Output contains non-finite values"
assert out.abs().max() > 0, "Output is all zeros"
# Verify backward works
model.train()
out = model(input_ids).logits
out.sum().backward()
# Check that LoRA params got gradients
embed = model.model.model.embed_tokens
has_embed_grad = any(
p.grad is not None and p.grad.abs().sum() > 0
for n, p in embed.named_parameters()
if "lora" in n
)
assert has_embed_grad, "Embedding LoRA params got no gradients"
del model
class TestQuantizedModels:
"""Test kernels with quantized base weights."""
def test_nf4_qlora_forward_backward(self):
"""NF4 QLoRA with kernel patches."""
from peft import PeftModelForCausalLM
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=DTYPE,
bnb_4bit_use_double_quant=True,
)
base = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
attn_implementation="eager",
)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
model = PeftModelForCausalLM(base, lora_config)
cfg = DictDefault(
{
"base_model": MODEL_NAME,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_mlp_kernel": True,
}
)
patch_self_attn_lora(cfg)
apply_lora_kernel_patches(model, cfg)
model.train()
ids = torch.randint(0, 1000, (1, 32), device=DEVICE)
out = model(ids).logits
assert torch.isfinite(out).all()
out.sum().backward()
has_grads = sum(
1 for n, p in model.named_parameters() if p.grad is not None and "lora" in n
)
assert has_grads > 0, "No LoRA gradients"
del model
def test_nf4_single_quant(self):
"""NF4 without double quantization."""
from peft import PeftModelForCausalLM
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=DTYPE,
)
base = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
attn_implementation="eager",
)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
model = PeftModelForCausalLM(base, lora_config)
cfg = DictDefault(
{
"base_model": MODEL_NAME,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_mlp_kernel": True,
}
)
patch_self_attn_lora(cfg)
apply_lora_kernel_patches(model, cfg)
model.train()
ids = torch.randint(0, 1000, (1, 32), device=DEVICE)
out = model(ids).logits
assert torch.isfinite(out).all()
out.sum().backward()
has_grads = sum(
1 for n, p in model.named_parameters() if p.grad is not None and "lora" in n
)
assert has_grads > 0
del model
class TestTritonDoRA:
"""Test Triton DoRA kernel against reference implementation."""
def test_triton_dora_scale(self):
from axolotl.kernels.dora import triton_dora_scale
from axolotl.kernels.lora import _compute_dora_scale
        # Random weights with dimensions on the order of the test model's hidden size
out_feat, in_feat, rank = 1024, 1024, 8
W = torch.randn(out_feat, in_feat, dtype=DTYPE, device=DEVICE)
A = torch.randn(rank, in_feat, dtype=DTYPE, device=DEVICE)
B = torch.randn(out_feat, rank, dtype=DTYPE, device=DEVICE)
magnitude = torch.randn(out_feat, dtype=DTYPE, device=DEVICE).abs() + 0.1
s = 2.0
# Clear cache to force recomputation
if hasattr(magnitude, "_dora_cache"):
del magnitude._dora_cache
ref = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE)
tri = triton_dora_scale(W, None, A, B, s, magnitude, DTYPE)
_compare_tensors(tri, ref, "Triton DoRA scale", atol=1e-2, rtol=1e-2)
def test_triton_dora_scale_small(self):
"""Test with K/V projection dimensions (smaller out_features)."""
from axolotl.kernels.dora import triton_dora_scale
from axolotl.kernels.lora import _compute_dora_scale
out_feat, in_feat, rank = 128, 1024, 8
W = torch.randn(out_feat, in_feat, dtype=DTYPE, device=DEVICE)
A = torch.randn(rank, in_feat, dtype=DTYPE, device=DEVICE)
B = torch.randn(out_feat, rank, dtype=DTYPE, device=DEVICE)
magnitude = torch.randn(out_feat, dtype=DTYPE, device=DEVICE).abs() + 0.1
s = 2.0
if hasattr(magnitude, "_dora_cache"):
del magnitude._dora_cache
ref = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE)
tri = triton_dora_scale(W, None, A, B, s, magnitude, DTYPE)
_compare_tensors(tri, ref, "Triton DoRA scale (small)", atol=1e-2, rtol=1e-2)
# ============================================================
# Regression tests for review fixes
# ============================================================
class TestDoRAEmbeddingNoDoubleScale:
"""Regression: DoRA embedding forward must save the pre-scaled combined
tensor, not the already-scaled result, so backward computes d_mag correctly."""
def test_dora_magnitude_gradient_magnitude(self):
"""d_mag should be O(1) relative to the gradient, not O(mag_scale^2)."""
from peft import PeftModelForCausalLM
from axolotl.kernels.lora import apply_lora_embedding
base = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=DTYPE,
attn_implementation="eager",
).to(DEVICE)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
use_dora=True,
target_modules=["embed_tokens"],
)
model = PeftModelForCausalLM(base, lora_config)
model.train()
embed = model.model.model.embed_tokens
ids = torch.randint(0, 1000, (2, 16), device=DEVICE)
# Run PEFT reference to get reference d_mag
peft_out = embed(ids)
peft_out.sum().backward()
peft_mag_grad = None
for n, p in embed.named_parameters():
if "magnitude" in n and p.grad is not None:
peft_mag_grad = p.grad.clone()
embed.zero_grad()
# Run kernel
kern_out = apply_lora_embedding(embed, ids)
kern_out.to(peft_out.dtype).sum().backward()
kern_mag_grad = None
for n, p in embed.named_parameters():
if "magnitude" in n and p.grad is not None:
kern_mag_grad = p.grad.clone()
embed.zero_grad()
assert peft_mag_grad is not None, "PEFT should produce magnitude gradients"
assert kern_mag_grad is not None, "Kernel should produce magnitude gradients"
# Key check: gradients should be same order of magnitude
# Double-scaling would make kern_mag_grad ~mag_scale times too large
ratio = kern_mag_grad.abs().mean() / peft_mag_grad.abs().mean()
assert 0.5 < ratio < 2.0, (
f"Magnitude gradient ratio kernel/peft = {ratio:.3f}, "
f"expected ~1.0 (double-scaling would give >> 1)"
)
del model
class TestDoraCacheInvalidation:
"""Regression: DoRA weight norm cache must invalidate after in-place
param updates (optimizer steps), not just pointer changes."""
def test_cache_invalidates_on_inplace_update(self):
W = torch.randn(64, 64, dtype=DTYPE, device=DEVICE)
A = torch.randn(8, 64, dtype=DTYPE, device=DEVICE)
B = torch.randn(64, 8, dtype=DTYPE, device=DEVICE)
magnitude = torch.randn(64, dtype=DTYPE, device=DEVICE).abs() + 0.1
s = 2.0
# Clear any existing cache
if hasattr(magnitude, "_dora_cache"):
del magnitude._dora_cache
# First call populates cache
result1 = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE)
# Simulate optimizer in-place update (pointer stays same, content changes)
old_ptr = A.data_ptr()
A.data.add_(torch.randn_like(A) * 0.1)
assert A.data_ptr() == old_ptr, "Pointer should not change for in-place ops"
# Second call must detect the change and recompute
result2 = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE)
# Results should differ since A changed
assert not torch.allclose(result1, result2, atol=1e-4), (
"DoRA scale should change after in-place param update — cache not invalidated!"
)
class TestEmbeddingPaddingIdxGrad:
"""Regression: custom embedding backward must zero out gradients at
padding_idx positions, matching F.embedding behavior."""
def test_padding_idx_gradient_is_zero(self):
from axolotl.kernels.lora import LoRA_Embedding
vocab, hidden, rank = 100, 32, 4
W = torch.randn(
vocab, hidden, dtype=torch.float32, device=DEVICE, requires_grad=False
)
A = torch.randn(
rank, vocab, dtype=torch.float32, device=DEVICE, requires_grad=True
)
B = torch.randn(
hidden, rank, dtype=torch.float32, device=DEVICE, requires_grad=True
)
s = 2.0
padding_idx = 0
# Input containing the padding token
x = torch.tensor([[padding_idx, 1, 2, padding_idx, 3]], device=DEVICE)
out = LoRA_Embedding.apply(
x,
W,
A,
B,
s,
None,
padding_idx,
None,
2.0,
False,
False, # max_norm, norm_type, scale_grad_by_freq, sparse
)
out.sum().backward()
# The gradient for A at the padding_idx column should be zero
# A is [rank, vocab], so A.grad[:, padding_idx] should be zero
assert A.grad is not None
pad_grad = A.grad[:, padding_idx]
assert torch.all(pad_grad == 0), (
f"Gradient at padding_idx={padding_idx} should be zero, got {pad_grad}"
)
# Non-padding positions should have non-zero gradients
non_pad_grad = A.grad[:, 1]
assert non_pad_grad.abs().sum() > 0, "Non-padding gradients should be non-zero"
class TestEmbeddingScaleGradByFreq:
"""Regression: custom embedding backward must scale gradients by
inverse frequency when scale_grad_by_freq=True."""
def test_repeated_tokens_get_scaled_gradients(self):
from axolotl.kernels.lora import LoRA_Embedding
vocab, hidden, rank = 100, 32, 4
W = torch.randn(
vocab, hidden, dtype=torch.float32, device=DEVICE, requires_grad=False
)
# Run WITHOUT scale_grad_by_freq
A1 = torch.randn(
rank, vocab, dtype=torch.float32, device=DEVICE, requires_grad=True
)
B1 = torch.randn(
hidden, rank, dtype=torch.float32, device=DEVICE, requires_grad=True
)
# Token 5 appears 3 times
x = torch.tensor([[5, 5, 5, 10, 20]], device=DEVICE)
out1 = LoRA_Embedding.apply(
x,
W,
A1,
B1,
2.0,
None,
None,
None,
2.0,
False,
False,
)
out1.sum().backward()
grad_no_scale = A1.grad[:, 5].clone()
# Run WITH scale_grad_by_freq
A2 = A1.data.clone().requires_grad_(True)
B2 = B1.data.clone().requires_grad_(True)
out2 = LoRA_Embedding.apply(
x,
W,
A2,
B2,
2.0,
None,
None,
None,
2.0,
True,
False,
)
out2.sum().backward()
grad_with_scale = A2.grad[:, 5].clone()
# With scale_grad_by_freq, token 5 (count=3) should have grad / 3
expected_ratio = 1.0 / 3.0
actual_ratio = grad_with_scale.abs().mean() / grad_no_scale.abs().mean()
assert abs(actual_ratio - expected_ratio) < 0.01, (
f"scale_grad_by_freq ratio for count=3 token: expected {expected_ratio:.3f}, "
f"got {actual_ratio:.3f}"
)
class TestEmbeddingDropoutNotAppliedToBase:
"""Regression: embedding dropout must NOT be applied to the base embedding
output — PEFT's Embedding.forward does not use lora_dropout."""
def test_kernel_matches_peft_with_dropout_config(self):
"""Even with lora_dropout>0, embedding output should match PEFT exactly."""
from peft import PeftModelForCausalLM
from axolotl.kernels.lora import apply_lora_embedding
base = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=DTYPE,
attn_implementation="eager",
).to(DEVICE)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.5, # high dropout
target_modules=["embed_tokens"],
)
model = PeftModelForCausalLM(base, lora_config)
model.train() # training mode — dropout would be active if applied
embed = model.model.model.embed_tokens
ids = torch.randint(0, 1000, (2, 16), device=DEVICE)
        # Run the kernel twice: if dropout were applied, the two results would differ
with torch.no_grad():
peft_out = embed(ids)
kern1 = apply_lora_embedding(embed, ids)
kern2 = apply_lora_embedding(embed, ids)
# Kernel should be deterministic (no dropout)
_compare_tensors(
kern1.to(peft_out.dtype),
kern2.to(peft_out.dtype),
"Embedding deterministic (no dropout)",
atol=0,
rtol=0,
)
# And should match PEFT
_compare_tensors(
kern1.to(peft_out.dtype),
peft_out,
"Embedding matches PEFT with dropout config",
)
del model