feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]
* feat: LoRA kernel support for bias, dropout, dora, embeddings * chore: lint * chore: lint * address PR feedback, add regression tests, add fsdp2 tests for lora kernels * update tests for new sigs * update tests now that bias and dropout are supported
This commit is contained in:
147
src/axolotl/kernels/dora.py
Normal file
147
src/axolotl/kernels/dora.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
|
||||||
|
|
||||||
|
Fuses the weight norm computation and magnitude scaling to avoid
|
||||||
|
materializing the full [out_features, in_features] combined weight matrix.
|
||||||
|
The B@A product is computed row-by-row inside the kernel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .quantize import dequantize
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _dora_fused_norm_kernel(
|
||||||
|
# Pointers
|
||||||
|
W_ptr, # base weight [out, in] (dequantized, row-major)
|
||||||
|
B_ptr, # LoRA B [out, rank] (row-major)
|
||||||
|
A_ptr, # LoRA A [rank, in] (row-major)
|
||||||
|
mag_ptr, # magnitude vector [out]
|
||||||
|
out_ptr, # output mag_norm_scale [out]
|
||||||
|
# Shapes
|
||||||
|
out_features,
|
||||||
|
in_features,
|
||||||
|
rank,
|
||||||
|
# Scaling
|
||||||
|
lora_scale, # float scaling factor
|
||||||
|
# Block sizes
|
||||||
|
BLOCK_IN: tl.constexpr,
|
||||||
|
BLOCK_R: tl.constexpr, # >= rank, power of 2
|
||||||
|
):
|
||||||
|
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
|
||||||
|
|
||||||
|
Each program handles one output row. B[row,:] is loaded once (small),
|
||||||
|
then we tile over in_features computing the dot product with A[:,tile]
|
||||||
|
and accumulating the squared norm.
|
||||||
|
|
||||||
|
This avoids materializing the full [out, in] B@A matrix.
|
||||||
|
"""
|
||||||
|
row = tl.program_id(0)
|
||||||
|
if row >= out_features:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Accumulate squared norm across tiles of in_features
|
||||||
|
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||||
|
|
||||||
|
for start in range(0, in_features, BLOCK_IN):
|
||||||
|
cols = start + tl.arange(0, BLOCK_IN)
|
||||||
|
col_mask = cols < in_features
|
||||||
|
|
||||||
|
# Load W[row, cols]
|
||||||
|
w_vals = tl.load(
|
||||||
|
W_ptr + row * in_features + cols,
|
||||||
|
mask=col_mask,
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
# Compute (B[row,:] @ A[:, cols]) for this tile
|
||||||
|
# Load B[row, r] as scalar and A[r, cols] as vector for each r
|
||||||
|
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||||
|
for r in tl.static_range(BLOCK_R):
|
||||||
|
# Load scalar B[row, r]
|
||||||
|
b_val = tl.load(
|
||||||
|
B_ptr + row * rank + r,
|
||||||
|
mask=(r < rank),
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
# Load vector A[r, cols]
|
||||||
|
a_vals = tl.load(
|
||||||
|
A_ptr + r * in_features + cols,
|
||||||
|
mask=(col_mask & (r < rank)),
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
ba_vals += b_val * a_vals
|
||||||
|
|
||||||
|
# Combined: W + s * (B @ A)
|
||||||
|
combined = w_vals + lora_scale * ba_vals
|
||||||
|
|
||||||
|
# Accumulate squared values
|
||||||
|
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
|
||||||
|
|
||||||
|
# Reduce to scalar norm
|
||||||
|
norm_sq = tl.sum(norm_sq_acc, axis=0)
|
||||||
|
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
|
||||||
|
|
||||||
|
# Load magnitude and compute scale
|
||||||
|
mag = tl.load(mag_ptr + row).to(tl.float32)
|
||||||
|
scale = mag / norm
|
||||||
|
|
||||||
|
tl.store(out_ptr + row, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_dora_scale(
|
||||||
|
W: torch.Tensor,
|
||||||
|
W_quant,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
s: float,
|
||||||
|
magnitude: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute DoRA mag_norm_scale using fused Triton kernel.
|
||||||
|
|
||||||
|
Computes B@A row-by-row inside the kernel, avoiding the full
|
||||||
|
[out_features, in_features] materialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
W: base weight [out, in] (possibly quantized)
|
||||||
|
W_quant: quantization state
|
||||||
|
A: LoRA A [rank, in]
|
||||||
|
B: LoRA B [out, rank]
|
||||||
|
s: LoRA scaling factor
|
||||||
|
magnitude: learned magnitude [out]
|
||||||
|
dtype: compute dtype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
|
||||||
|
"""
|
||||||
|
# Dequantize W to [out, in]
|
||||||
|
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
|
||||||
|
|
||||||
|
out_features, in_features = W_full.shape
|
||||||
|
rank = A.shape[0]
|
||||||
|
|
||||||
|
out = torch.empty(out_features, dtype=dtype, device=W.device)
|
||||||
|
|
||||||
|
# Block sizes
|
||||||
|
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
|
||||||
|
BLOCK_R = triton.next_power_of_2(rank)
|
||||||
|
|
||||||
|
_dora_fused_norm_kernel[(out_features,)](
|
||||||
|
W_full,
|
||||||
|
B.contiguous().to(dtype),
|
||||||
|
A.contiguous().to(dtype),
|
||||||
|
magnitude.contiguous(),
|
||||||
|
out,
|
||||||
|
out_features=out_features,
|
||||||
|
in_features=in_features,
|
||||||
|
rank=rank,
|
||||||
|
lora_scale=s,
|
||||||
|
BLOCK_IN=BLOCK_IN,
|
||||||
|
BLOCK_R=BLOCK_R,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out.detach()
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -105,6 +105,10 @@ def dequantize(
|
|||||||
# Extract quantization state
|
# Extract quantization state
|
||||||
if not isinstance(quant_state, list):
|
if not isinstance(quant_state, list):
|
||||||
# New style quant_state class
|
# New style quant_state class
|
||||||
|
# Non-double-quantized models have offset=None and state2=None
|
||||||
|
if quant_state.offset is None or quant_state.state2 is None:
|
||||||
|
# Fall back to bitsandbytes standard dequantize
|
||||||
|
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
|
||||||
absmax = quant_state.absmax.to(target_device)
|
absmax = quant_state.absmax.to(target_device)
|
||||||
shape = quant_state.shape
|
shape = quant_state.shape
|
||||||
dtype = quant_state.dtype
|
dtype = quant_state.dtype
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from torch import nn
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
|
apply_lora_embedding,
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
@@ -370,13 +371,13 @@ def apply_lora_kernel_patches(
|
|||||||
active_adapter = model.active_adapter
|
active_adapter = model.active_adapter
|
||||||
lora_config = model.model.peft_config[active_adapter]
|
lora_config = model.model.peft_config[active_adapter]
|
||||||
|
|
||||||
# Only patch if conditions are met
|
# Log what features are active
|
||||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
if lora_config.lora_dropout > 0:
|
||||||
|
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
|
||||||
if not can_patch:
|
if lora_config.bias != "none":
|
||||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
|
||||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
if lora_config.use_dora:
|
||||||
return model
|
LOG.info("LoRA kernels: DoRA enabled")
|
||||||
|
|
||||||
# This needs to be reset after patching
|
# This needs to be reset after patching
|
||||||
original_level = LOG.getEffectiveLevel()
|
original_level = LOG.getEffectiveLevel()
|
||||||
@@ -419,44 +420,33 @@ def apply_lora_kernel_patches(
|
|||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||||
]
|
]
|
||||||
can_patch_qkv = all(
|
can_patch_qkv = all(
|
||||||
hasattr(module, "lora_A")
|
hasattr(module, "lora_A") for module in layer_modules
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_qkv:
|
if can_patch_qkv:
|
||||||
# Add optimized implementation
|
|
||||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention QKV projections - requires LoRA "
|
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
||||||
"adapters and no lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
if cfg.lora_o_kernel:
|
if cfg.lora_o_kernel:
|
||||||
# Output patching
|
# Output patching
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
]
|
]
|
||||||
can_patch_o = all(
|
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||||
hasattr(module, "lora_A")
|
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if can_patch_o:
|
if can_patch_o:
|
||||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention output projection - requires LoRA "
|
"Cannot patch some attention output projection - requires LoRA adapters"
|
||||||
"adapters and no lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||||
if cfg.lora_mlp_kernel:
|
if cfg.lora_mlp_kernel:
|
||||||
# MLP patching
|
# MLP patching
|
||||||
can_patch_mlp = all(
|
can_patch_mlp = all(
|
||||||
hasattr(proj, "lora_A")
|
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
|
||||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for proj in (gate_proj, up_proj, down_proj)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_mlp:
|
if can_patch_mlp:
|
||||||
@@ -464,15 +454,50 @@ def apply_lora_kernel_patches(
|
|||||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
"Cannot patch some MLP layers - requires LoRA adapters"
|
||||||
"lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Patch embedding layers (model-level, not per-layer)
|
||||||
|
if cfg.lora_embedding_kernel:
|
||||||
|
_patch_embedding_layers(model, cfg)
|
||||||
|
|
||||||
LOG.setLevel(original_level)
|
LOG.setLevel(original_level)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
|
||||||
|
"""Patch embedding layers with fused LoRA kernel.
|
||||||
|
|
||||||
|
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
|
||||||
|
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
|
||||||
|
"""
|
||||||
|
pretrained_model = model.model
|
||||||
|
patched = 0
|
||||||
|
|
||||||
|
# Find embedding modules - check common locations
|
||||||
|
for attr_path in [
|
||||||
|
("model", "embed_tokens"),
|
||||||
|
("model", "language_model", "embed_tokens"),
|
||||||
|
]:
|
||||||
|
parent = pretrained_model
|
||||||
|
for attr in attr_path:
|
||||||
|
parent = getattr(parent, attr, None)
|
||||||
|
if parent is None:
|
||||||
|
break
|
||||||
|
if parent is not None and hasattr(parent, "lora_embedding_A"):
|
||||||
|
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
|
||||||
|
parent.forward = types.MethodType(apply_lora_embedding, parent)
|
||||||
|
patched += 1
|
||||||
|
|
||||||
|
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
|
||||||
|
# when included in target_modules. No special embedding handling needed since
|
||||||
|
# PEFT wraps it as a Linear (not Embedding) even for tied models.
|
||||||
|
|
||||||
|
if not patched:
|
||||||
|
LOG.debug("No embedding layers with LoRA found to patch")
|
||||||
|
|
||||||
|
|
||||||
class FakeMLP(nn.Module):
|
class FakeMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
placeholder MLP for triton patching
|
placeholder MLP for triton patching
|
||||||
|
|||||||
@@ -703,6 +703,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
lora_embedding_kernel: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
chunked_cross_entropy: bool | None = Field(
|
chunked_cross_entropy: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -1313,6 +1319,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
or data.get("lora_o_kernel")
|
or data.get("lora_o_kernel")
|
||||||
|
or data.get("lora_embedding_kernel")
|
||||||
):
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp_config") is not None
|
is_fsdp = data.get("fsdp_config") is not None
|
||||||
@@ -1360,7 +1367,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("adapter") in ["lora", "qlora"]:
|
if data.get("adapter") in ["lora", "qlora"]:
|
||||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||||
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
kernel_fields = [
|
||||||
|
"lora_mlp_kernel",
|
||||||
|
"lora_qkv_kernel",
|
||||||
|
"lora_o_kernel",
|
||||||
|
"lora_embedding_kernel",
|
||||||
|
]
|
||||||
if (
|
if (
|
||||||
any(data.get(k) is not None for k in kernel_fields)
|
any(data.get(k) is not None for k in kernel_fields)
|
||||||
or any(data.get(k) for k in unsloth_fields)
|
or any(data.get(k) for k in unsloth_fields)
|
||||||
@@ -1373,10 +1385,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("trust_remote_code"):
|
if data.get("trust_remote_code"):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
|
||||||
if data.get("lora_dropout") != 0:
|
|
||||||
return data
|
|
||||||
|
|
||||||
# Check multi-GPU compatibility
|
# Check multi-GPU compatibility
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||||
@@ -1398,6 +1406,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("lora_o_kernel") is None:
|
if data.get("lora_o_kernel") is None:
|
||||||
data["lora_o_kernel"] = True
|
data["lora_o_kernel"] = True
|
||||||
|
|
||||||
|
if data.get("lora_embedding_kernel") is None:
|
||||||
|
data["lora_embedding_kernel"] = True
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Auto-enabling LoRA kernel optimizations for faster training. "
|
"Auto-enabling LoRA kernel optimizations for faster training. "
|
||||||
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
||||||
|
|||||||
@@ -681,15 +681,7 @@ class LoRAValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernels_dora(cls, data):
|
def check_lora_kernels_dora(cls, data):
|
||||||
if (
|
# DoRA is now supported by lora kernels
|
||||||
data.get("lora_mlp_kernel")
|
|
||||||
or data.get("lora_qkv_kernel")
|
|
||||||
or data.get("lora_o_kernel")
|
|
||||||
) and data.get("peft_use_dora"):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
|
||||||
"compatible with DoRA at the moment."
|
|
||||||
)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
|
|
||||||
proj.base_layer = base_layer
|
proj.base_layer = base_layer
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||||
# quant_state should be None since weight is bf16, not FP8
|
# quant_state should be None since weight is bf16, not FP8
|
||||||
self.assertIsNone(quant_state)
|
self.assertIsNone(quant_state)
|
||||||
|
|
||||||
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
scale_inv = torch.ones(1)
|
scale_inv = torch.ones(1)
|
||||||
base_layer.weight_scale_inv = scale_inv
|
base_layer.weight_scale_inv = scale_inv
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||||
self.assertIs(quant_state, scale_inv)
|
self.assertIs(quant_state, scale_inv)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def mock_proj():
|
|||||||
def test_get_lora_parameters(mock_proj):
|
def test_get_lora_parameters(mock_proj):
|
||||||
"""Tests get_lora_parameters function"""
|
"""Tests get_lora_parameters function"""
|
||||||
# Test with LoRA enabled
|
# Test with LoRA enabled
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||||
|
|
||||||
assert isinstance(W, torch.Tensor)
|
assert isinstance(W, torch.Tensor)
|
||||||
assert W.shape == (128, 64)
|
assert W.shape == (128, 64)
|
||||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
|||||||
|
|
||||||
# Test with LoRA disabled
|
# Test with LoRA disabled
|
||||||
mock_proj.disable_adapters = True
|
mock_proj.disable_adapters = True
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
# Test with merged state
|
# Test with merged state
|
||||||
mock_proj.disable_adapters = False
|
mock_proj.disable_adapters = False
|
||||||
mock_proj.merged = True
|
mock_proj.merged = True
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
1245
tests/e2e/kernels/test_lora_features.py
Normal file
1245
tests/e2e/kernels/test_lora_features.py
Normal file
File diff suppressed because it is too large
Load Diff
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Test LoRA kernels under FSDP2 multi-GPU training.
|
||||||
|
|
||||||
|
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
|
||||||
|
lora_embedding_kernel work correctly with FSDP2 sharding, including
|
||||||
|
with bias, dropout, and DoRA enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import require_torch_2_7_0
|
||||||
|
|
||||||
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _run_training(temp_dir, cfg):
|
||||||
|
"""Write config and launch multi-GPU training."""
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
||||||
|
"""Base config for LoRA + FSDP2 + kernel tests."""
|
||||||
|
cfg = {
|
||||||
|
"base_model": "Qwen/Qwen3-0.6B",
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:1%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 3,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"bf16": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
# Enable all LoRA kernels
|
||||||
|
"lora_mlp_kernel": True,
|
||||||
|
"lora_qkv_kernel": True,
|
||||||
|
"lora_o_kernel": True,
|
||||||
|
"lora_embedding_kernel": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
cfg.update(overrides)
|
||||||
|
return DictDefault(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFSDP2LoRAKernels:
|
||||||
|
"""Test LoRA kernels under FSDP2."""
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_basic(self, temp_dir):
|
||||||
|
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
|
||||||
|
cfg = _base_lora_fsdp2_config(temp_dir)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_with_dropout(self, temp_dir):
|
||||||
|
"""LoRA kernels + dropout + FSDP2."""
|
||||||
|
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_with_dora(self, temp_dir):
|
||||||
|
"""LoRA kernels + DoRA + FSDP2."""
|
||||||
|
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
|
||||||
|
"""LoRA kernels + DoRA + dropout + FSDP2."""
|
||||||
|
cfg = _base_lora_fsdp2_config(
|
||||||
|
temp_dir,
|
||||||
|
peft_use_dora=True,
|
||||||
|
lora_dropout=0.05,
|
||||||
|
)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_patch_conditions():
|
def test_kernel_patch_conditions():
|
||||||
"""Test various conditions that should prevent kernel patching."""
|
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
|
||||||
test_configs = [
|
test_configs = [
|
||||||
# Dropout prevents patching
|
# Dropout — kernels now support this
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
|
|||||||
"lora_dropout": 0.1,
|
"lora_dropout": 0.1,
|
||||||
"bias": "none",
|
"bias": "none",
|
||||||
},
|
},
|
||||||
# Bias prevents patching
|
# Bias — kernels now support this
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
|
|||||||
model = PeftModelForCausalLM(model, peft_config)
|
model = PeftModelForCausalLM(model, peft_config)
|
||||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||||
|
|
||||||
# Should not patch
|
|
||||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||||
layer = patched_model.model.model.layers[0].mlp
|
layer = patched_model.model.model.layers[0].mlp
|
||||||
|
|
||||||
# Verify no patches applied
|
# Verify patches ARE applied (dropout and bias are now supported)
|
||||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
assert (
|
||||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
layer.forward.__func__ is apply_lora_mlp_swiglu
|
||||||
|
or layer.forward.__func__ is apply_lora_mlp_geglu
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_config_options():
|
def test_kernel_config_options():
|
||||||
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||||
"""Test model loading with dropout non-zero should not patch."""
|
"""Test model loading with dropout non-zero DOES patch (now supported)."""
|
||||||
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
|||||||
# Load config
|
# Load config
|
||||||
cfg = load_cfg(str(path))
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
# Get original attention class
|
|
||||||
attention_cls = get_attention_cls_from_config(cfg)
|
|
||||||
|
|
||||||
# Store original state before patching
|
|
||||||
original_forward_method = attention_cls.forward
|
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
# We call modelloader as that's where the patches are applied
|
|
||||||
# despite the fact that we're not using it to load the model
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer)
|
model_loader = ModelLoader(cfg, tokenizer)
|
||||||
|
|
||||||
# Apply patch
|
# Apply patches — should succeed even with dropout > 0
|
||||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||||
|
|
||||||
# Verify patch was not applied
|
|
||||||
assert attention_cls.forward == original_forward_method
|
|
||||||
|
|
||||||
# Apply apply_lora_kernel_patches
|
|
||||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||||
|
|
||||||
# Verify patch was not applied
|
# Verify patches WERE applied (dropout is now supported by kernels)
|
||||||
layers = get_layers(model)
|
layers = get_layers(model)
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
for self_attn in find_self_attn_in_layer(layer):
|
for self_attn in find_self_attn_in_layer(layer):
|
||||||
assert not hasattr(self_attn, "apply_qkv")
|
assert hasattr(self_attn, "apply_qkv")
|
||||||
assert not hasattr(self_attn, "apply_o")
|
assert hasattr(self_attn, "apply_o")
|
||||||
|
|||||||
@@ -28,20 +28,22 @@ class TestLoRAConfigValidation:
|
|||||||
result = validate_config(valid_config)
|
result = validate_config(valid_config)
|
||||||
assert result["adapter"] == "lora"
|
assert result["adapter"] == "lora"
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not compatible with DoRA"):
|
# DoRA is now compatible with lora kernels
|
||||||
invalid_config = DictDefault(
|
dora_kernel_config = DictDefault(
|
||||||
{
|
{
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_mlp_kernel": True,
|
"lora_mlp_kernel": True,
|
||||||
"peft_use_dora": True,
|
"peft_use_dora": True,
|
||||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"learning_rate": 1e-5,
|
"learning_rate": 1e-5,
|
||||||
"base_model": "dummy_model",
|
"base_model": "dummy_model",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
validate_config(invalid_config)
|
result = validate_config(dora_kernel_config)
|
||||||
|
assert result["lora_mlp_kernel"] is True
|
||||||
|
assert result["peft_use_dora"] is True
|
||||||
|
|
||||||
def test_qlora_4bit_validation(self):
|
def test_qlora_4bit_validation(self):
|
||||||
"""Test QLoRA 4-bit configuration validation"""
|
"""Test QLoRA 4-bit configuration validation"""
|
||||||
|
|||||||
@@ -38,6 +38,11 @@ class TestLoRAParameterFreezing:
|
|||||||
|
|
||||||
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||||
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||||
|
mock_layer.lora_B["default"].bias = None
|
||||||
|
|
||||||
|
# Required by get_lora_parameters for dropout/DoRA extraction
|
||||||
|
mock_layer.lora_dropout = {}
|
||||||
|
mock_layer.lora_magnitude_vector = None
|
||||||
else:
|
else:
|
||||||
mock_layer.weight = base_layer.weight
|
mock_layer.weight = base_layer.weight
|
||||||
mock_layer.bias = base_layer.bias
|
mock_layer.bias = base_layer.bias
|
||||||
@@ -48,7 +53,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test that LoRA parameters are None when adapters are disabled."""
|
"""Test that LoRA parameters are None when adapters are disabled."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -62,7 +67,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test that LoRA parameters are None when adapters are merged."""
|
"""Test that LoRA parameters are None when adapters are merged."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -77,7 +82,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test parameter behavior when no adapters are present."""
|
"""Test parameter behavior when no adapters are present."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=False)
|
layer = self.create_mock_lora_layer(has_adapters=False)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -94,7 +99,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# All parameters should be returned
|
# All parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -110,7 +115,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Check shape consistency
|
# Check shape consistency
|
||||||
assert W.shape == (512, 256)
|
assert W.shape == (512, 256)
|
||||||
@@ -124,7 +129,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert W.dtype == self.dtype
|
assert W.dtype == self.dtype
|
||||||
assert b.dtype == self.dtype
|
assert b.dtype == self.dtype
|
||||||
@@ -138,7 +143,7 @@ class TestLoRAParameterFreezing:
|
|||||||
quant_state_mock = Mock()
|
quant_state_mock = Mock()
|
||||||
layer.base_layer.weight.quant_state = quant_state_mock
|
layer.base_layer.weight.quant_state = quant_state_mock
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert quant_state == quant_state_mock
|
assert quant_state == quant_state_mock
|
||||||
|
|
||||||
@@ -157,7 +162,7 @@ class TestLoRAParameterFreezing:
|
|||||||
|
|
||||||
layer.active_adapters = ["adapter2"]
|
layer.active_adapters = ["adapter2"]
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert s == 0.2
|
assert s == 0.2
|
||||||
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
||||||
@@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration:
|
|||||||
model = get_peft_model(base_model, lora_config)
|
model = get_peft_model(base_model, lora_config)
|
||||||
lora_layer = model.base_model.model.linear
|
lora_layer = model.base_model.model.linear
|
||||||
# Test with adapters enabled
|
# Test with adapters enabled
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
||||||
assert A is not None
|
assert A is not None
|
||||||
assert B is not None
|
assert B is not None
|
||||||
assert s is not None
|
assert s is not None
|
||||||
# Test with adapters disabled
|
# Test with adapters disabled
|
||||||
model.disable_adapter_layers()
|
model.disable_adapter_layers()
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
||||||
assert A is None
|
assert A is None
|
||||||
assert B is None
|
assert B is None
|
||||||
assert s is None
|
assert s is None
|
||||||
|
|||||||
Reference in New Issue
Block a user