feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]

* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
This commit is contained in:
Wing Lian
2026-03-22 13:53:19 -04:00
committed by GitHub
parent a67392c427
commit b3289fd190
13 changed files with 2847 additions and 448 deletions

147
src/axolotl/kernels/dora.py Normal file
View File

@@ -0,0 +1,147 @@
"""
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
Fuses the weight norm computation and magnitude scaling to avoid
materializing the full [out_features, in_features] combined weight matrix.
The B@A product is computed row-by-row inside the kernel.
"""
import torch
import triton
import triton.language as tl
from .quantize import dequantize
@triton.jit
def _dora_fused_norm_kernel(
# Pointers
W_ptr, # base weight [out, in] (dequantized, row-major)
B_ptr, # LoRA B [out, rank] (row-major)
A_ptr, # LoRA A [rank, in] (row-major)
mag_ptr, # magnitude vector [out]
out_ptr, # output mag_norm_scale [out]
# Shapes
out_features,
in_features,
rank,
# Scaling
lora_scale, # float scaling factor
# Block sizes
BLOCK_IN: tl.constexpr,
BLOCK_R: tl.constexpr, # >= rank, power of 2
):
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
Each program handles one output row. B[row,:] is loaded once (small),
then we tile over in_features computing the dot product with A[:,tile]
and accumulating the squared norm.
This avoids materializing the full [out, in] B@A matrix.
"""
row = tl.program_id(0)
if row >= out_features:
return
# Accumulate squared norm across tiles of in_features
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
for start in range(0, in_features, BLOCK_IN):
cols = start + tl.arange(0, BLOCK_IN)
col_mask = cols < in_features
# Load W[row, cols]
w_vals = tl.load(
W_ptr + row * in_features + cols,
mask=col_mask,
other=0.0,
).to(tl.float32)
# Compute (B[row,:] @ A[:, cols]) for this tile
# Load B[row, r] as scalar and A[r, cols] as vector for each r
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
for r in tl.static_range(BLOCK_R):
# Load scalar B[row, r]
b_val = tl.load(
B_ptr + row * rank + r,
mask=(r < rank),
other=0.0,
).to(tl.float32)
# Load vector A[r, cols]
a_vals = tl.load(
A_ptr + r * in_features + cols,
mask=(col_mask & (r < rank)),
other=0.0,
).to(tl.float32)
ba_vals += b_val * a_vals
# Combined: W + s * (B @ A)
combined = w_vals + lora_scale * ba_vals
# Accumulate squared values
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
# Reduce to scalar norm
norm_sq = tl.sum(norm_sq_acc, axis=0)
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
# Load magnitude and compute scale
mag = tl.load(mag_ptr + row).to(tl.float32)
scale = mag / norm
tl.store(out_ptr + row, scale)
def triton_dora_scale(
W: torch.Tensor,
W_quant,
A: torch.Tensor,
B: torch.Tensor,
s: float,
magnitude: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute DoRA mag_norm_scale using fused Triton kernel.
Computes B@A row-by-row inside the kernel, avoiding the full
[out_features, in_features] materialization.
Args:
W: base weight [out, in] (possibly quantized)
W_quant: quantization state
A: LoRA A [rank, in]
B: LoRA B [out, rank]
s: LoRA scaling factor
magnitude: learned magnitude [out]
dtype: compute dtype
Returns:
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
"""
# Dequantize W to [out, in]
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
out_features, in_features = W_full.shape
rank = A.shape[0]
out = torch.empty(out_features, dtype=dtype, device=W.device)
# Block sizes
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
BLOCK_R = triton.next_power_of_2(rank)
_dora_fused_norm_kernel[(out_features,)](
W_full,
B.contiguous().to(dtype),
A.contiguous().to(dtype),
magnitude.contiguous(),
out,
out_features=out_features,
in_features=in_features,
rank=rank,
lora_scale=s,
BLOCK_IN=BLOCK_IN,
BLOCK_R=BLOCK_R,
)
return out.detach()

File diff suppressed because it is too large Load Diff

View File

@@ -105,6 +105,10 @@ def dequantize(
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
# Non-double-quantized models have offset=None and state2=None
if quant_state.offset is None or quant_state.state2 is None:
# Fall back to bitsandbytes standard dequantize
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype

View File

@@ -12,6 +12,7 @@ from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_embedding,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
@@ -370,13 +371,13 @@ def apply_lora_kernel_patches(
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# Log what features are active
if lora_config.lora_dropout > 0:
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
if lora_config.bias != "none":
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
if lora_config.use_dora:
LOG.info("LoRA kernels: DoRA enabled")
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
@@ -419,44 +420,33 @@ def apply_lora_kernel_patches(
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention QKV projections - requires LoRA adapters"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention output projection - requires LoRA adapters"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
@@ -464,15 +454,50 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
"Cannot patch some MLP layers - requires LoRA adapters"
)
# Patch embedding layers (model-level, not per-layer)
if cfg.lora_embedding_kernel:
_patch_embedding_layers(model, cfg)
LOG.setLevel(original_level)
return model
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
"""Patch embedding layers with fused LoRA kernel.
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
"""
pretrained_model = model.model
patched = 0
# Find embedding modules - check common locations
for attr_path in [
("model", "embed_tokens"),
("model", "language_model", "embed_tokens"),
]:
parent = pretrained_model
for attr in attr_path:
parent = getattr(parent, attr, None)
if parent is None:
break
if parent is not None and hasattr(parent, "lora_embedding_A"):
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
parent.forward = types.MethodType(apply_lora_embedding, parent)
patched += 1
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
# when included in target_modules. No special embedding handling needed since
# PEFT wraps it as a Linear (not Embedding) even for tied models.
if not patched:
LOG.debug("No embedding layers with LoRA found to patch")
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching

View File

@@ -703,6 +703,12 @@ class AxolotlInputConfig(
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
lora_embedding_kernel: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
@@ -1313,6 +1319,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
or data.get("lora_embedding_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp_config") is not None
@@ -1360,7 +1367,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
"lora_o_kernel",
"lora_embedding_kernel",
]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
@@ -1373,10 +1385,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
@@ -1398,6 +1406,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
if data.get("lora_embedding_kernel") is None:
data["lora_embedding_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "

View File

@@ -681,15 +681,7 @@ class LoRAValidationMixin:
@model_validator(mode="before")
@classmethod
def check_lora_kernels_dora(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("peft_use_dora"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with DoRA at the moment."
)
# DoRA is now supported by lora kernels
return data
@model_validator(mode="before")