Activation function Triton kernels, LoRA custom autograd functions (#2324)

* LoRA + activation fn Triton kernels: initial commit

* implementing optims

* finalizing MLP LoRA kernels and progress on QKV / W kernels

* updates

* O projection optim

* adding monkey patching logic

* doc strings, typing, pre-commit fixes

* updates

* adding lora 8b kernels example

* working on fsdp support

* tests and fixes

* small fixes, getting tests to pass, adding doc strings

* integration tests for LoRA patching

* config.qmd

* remove unneeded pytest fixture

* fix

* review comments first pass

* improving tests, attention class agnostic patching

* adding support for more archs

* wip SiLU / GELU impls

* improved testing, small updates, etc.

* slightly updating docs

* rebase

* fixing test_attention_patching_integration

* additional review comments, fixing test in CI (hopefully)

* isolating problematic patching test

* relaxing allclose threshold to reduce flakiness

* fixing accidental change

* adding model arch agnostic attention class fetching

* removing unused activations
This commit is contained in:
Dan Saunders
2025-02-17 14:23:15 -05:00
committed by GitHub
parent 97a2fa2781
commit 3d8425fa91
22 changed files with 3102 additions and 22 deletions

View File

@@ -4,8 +4,8 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,4 @@
"""
modal application to run axolotl gpu tests in Modal
"""
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os

View File

@@ -305,6 +305,13 @@ lora_modules_to_save:
lora_fan_in_fan_out: false
# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`

127
docs/lora_optims.qmd Normal file
View File

@@ -0,0 +1,127 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
LoRA fine-tuning"
---
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`
<details>
The set of models we support is currently limited by our attention patching strategy,
which assumes (and replaces) specific code blocks for query / key / value and output
projections:
```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
```
Is replaced with:
```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
```
Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
We welcome testing of other model architectures and / or PRs to expand our patching
logic to be compatible with more of them.
</details>
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.
```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
- AMD can be used with experimental Triton support by setting the environment variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
- Targeted LoRA adapters cannot use Dropout
- This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
- This may limit model expressivity
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features in order to be useful.
## Implementation details
### Custom autograd functions
The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
pass for the entire MLP block.
For attention components, similar optimizations are provided through a function that
handles the query, key, and value projections, and a function that handles the output
projection. They are designed to work with the existing `transformers` attention
implementation via some monkey-patching logic.
### Triton kernels
Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.
### Integration
The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, while
calling the Triton kernels for the activation function computation. During the backward
pass, the kernel computes both the activation output and the required gradients, which
the autograd function then uses to compute the final gradients for the entire
computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

View File

@@ -0,0 +1,82 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels
# lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# These options enable our custom Triton kernels / autograd
# functions for MLP and attention calculations
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -167,7 +167,6 @@ def train(
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -201,6 +200,8 @@ def train(
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
@@ -229,6 +230,8 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)

View File

View File

@@ -0,0 +1,159 @@
"""
Module for definition of GEGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch
import triton
import triton.language as tl
SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π)
@triton.jit
def _geglu_fwd_kernel(
gate_ptr,
up_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""GEGLU forward kernel.
Args:
gate_ptr: Pointer to gate tensor [*, hidden_dim].
up_ptr: Pointer to up-projection tensor [*, hidden_dim].
out_ptr: Pointer to output tensor [*, hidden_dim].
n_elements: Total number of elements in the input tensors.
BLOCK_SIZE: Size of thread blocks for parallel computation.
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute activation in fp32 then convert back
gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
gelu_gate = gelu_gate.to(up.dtype)
result = gelu_gate * up
tl.store(out_ptr + offsets, result, mask=mask)
def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""GEGLU forward pass.
Args:
gate: Input gate tensor of shape [batch, seq_len, hidden_dim].
up: Up-projection tensor of shape [batch, seq_len, hidden_dim].
Returns:
torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim].
"""
batch, seq_len, hidden_dim = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
_geglu_fwd_kernel[grid](
gate_ptr=gate,
up_ptr=up,
out_ptr=out,
n_elements=n_elements,
BLOCK_SIZE=1024,
)
return out
@triton.jit
def _geglu_bwd_kernel(
grad_out_ptr,
gate_ptr,
up_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""GEGLU backward kernel. Stores gradient results in-place.
Args:
grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
gate_ptr: Pointer to gate tensor [*, hidden_dim].
up_ptr: Pointer to up-projection tensor [*, hidden_dim].
n_elements: Total number of elements in the input tensors.
BLOCK_SIZE: Size of thread blocks for parallel computation.
Note:
After kernel execution, tensors are modified in-place:
- `grad_out_ptr` contains GEGLU activation output (`h`)
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
- `up_ptr` contains gradient w.r.t up (`grad_up`)
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Forward pass
gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
gelu_gate = gelu_partial * gate
gelu_gate = gelu_gate.to(grad_out.dtype)
# Forward output
h = gelu_gate * up
# Compute gradients
grad_up = grad_out * gelu_gate
# Compute gate gradient using GELU derivative
temp = grad_out * up
t = 0.3989422804014327 # 1/sqrt(2*pi)
dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
grad_gate = temp.to(tl.float32) * dgelu_dgate
grad_gate = grad_gate.to(grad_out.dtype)
# Store results
tl.store(grad_out_ptr + offsets, h, mask=mask)
tl.store(gate_ptr + offsets, grad_gate, mask=mask)
tl.store(up_ptr + offsets, grad_up, mask=mask)
def geglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""GEGLU backward pass using in-place operations.
Args:
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
Returns:
Tuple containing:
- GEGLU activation output (`h`)
- Gradient with respect to gate (`grad_gate`)
- Gradient with respect to up (`grad_up`)
Note:
This function modifies its input tensors in-place to store results.
"""
n_elements = grad_output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
_geglu_bwd_kernel[grid](
grad_out_ptr=grad_output,
gate_ptr=gate,
up_ptr=up,
n_elements=n_elements,
BLOCK_SIZE=1024,
)
return grad_output, gate, up

779
src/axolotl/kernels/lora.py Normal file
View File

@@ -0,0 +1,779 @@
"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.
See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
]:
"""
Gets LoRA parameters from a projection module.
Args:
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
active_adapter = (
proj.active_adapters[0]
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
Args:
X: Input tensor [*, in_features]
W: Base weight matrix [out_features, in_features]
W_quant: Quantization state for W
A: LoRA A matrix [rank, in_features]
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
W = dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
class LoRA_MLP(torch.autograd.Function):
"""Optimized LoRA MLP implementation."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_quant: object | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_quant: object | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_quant: object | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
activation_fn: Callable,
activation_fn_backward: Callable,
inplace: bool | None = True,
) -> torch.Tensor:
"""
Forward pass for LoRA MLP.
Args:
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
Returns:
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_quant, down_A, down_B, down_scale
)
# Save for backward
ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
ctx.scales = (gate_scale, up_scale, down_scale)
ctx.quants = (gate_quant, up_quant, down_quant)
ctx.weights = (gate_weight, up_weight, down_weight)
ctx.activation_fn = activation_fn
ctx.activation_fn_backward = activation_fn_backward
ctx.inplace = inplace
return output
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_output: torch.Tensor,
) -> tuple[
torch.Tensor | None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
Args:
ctx: Context object storing tensors saved during forward pass
grad_output: Gradient of loss with respect to layer output
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
"""
(
X,
gate,
up,
gate_A,
gate_B,
up_A,
up_B,
down_A,
down_B,
) = ctx.saved_tensors
gate_scale, up_scale, down_scale = ctx.scales
gate_quant, up_quant, down_quant = ctx.quants
gate_weight, up_weight, down_weight = ctx.weights
# Transpose all LoRA matrices
gate_A, gate_B = (
gate_A.t() if gate_A is not None else None,
gate_B.t() if gate_B is not None else None,
)
up_A, up_B = (
up_A.t() if up_A is not None else None,
up_B.t() if up_B is not None else None,
)
down_A, down_B = (
down_A.t() if down_A is not None else None,
down_B.t() if down_B is not None else None,
)
# Reshape inputs
batch, seq_len, hd = X.shape
grad_output = grad_output.view(-1, grad_output.shape[-1])
X = X.view(-1, X.shape[-1])
gate = gate.view(-1, gate.shape[-1])
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection
DW = matmul_lora(
grad_output,
down_weight.t(),
down_quant,
down_B,
down_A,
down_scale,
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
if down_A is not None:
d_down_A = h.t() @ (grad_output @ down_B.t())
d_down_B = (down_A.t() @ h.t()) @ grad_output
d_down_A *= down_scale
d_down_B *= down_scale
if up_A is not None:
d_up_A = X.t() @ (grad_up @ up_B.t())
d_up_B = (up_A.t() @ X.t()) @ grad_up
d_up_A *= up_scale
d_up_B *= up_scale
if gate_A is not None:
d_gate_A = X.t() @ (grad_gate @ gate_B.t())
d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
d_gate_A *= gate_scale
d_gate_B *= gate_scale
# Compute input gradients
dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
if dX is not None:
# Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
dX = torch.matmul(grad_up, up_weight.t())
del up_weight
# Note the .to(dtype) only where mixing LoRA with base weights
if up_A is not None:
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
del gate_weight
if gate_A is not None:
dX += (
grad_gate
@ gate_B.to(dtype).t()
@ (gate_scale * gate_A.to(dtype).t())
)
# Reshape back
dX = dX.view(batch, seq_len, hd)
# Return gradients in correct order matching forward inputs
return (
dX,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
)
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with SwiGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
swiglu_forward,
swiglu_backward,
inplace,
)
return out
def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with GEGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
geglu_forward,
geglu_backward,
inplace,
)
return out
class LoRA_QKV(torch.autograd.Function):
"""
Optimized LoRA QKV implementation with quantization support.
Implements efficient computation of query, key, value projections with LoRA,
supporting quantization and memory optimization.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
v_scale: float,
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass computing Q, K, V projections with LoRA.
Args:
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
v_scale: Value LoRA scale
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.inplace = inplace
return Q, K, V
@staticmethod
@torch_amp_custom_fwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
v_grad: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
]:
"""
Backward pass computing gradients for LoRA QKV.
Args:
ctx: Autograd context
q_grad: Gradient for query projection
k_grad: Gradient for key projection
v_grad: Gradient for value projection
Returns:
Tuple containing gradients for all forward inputs
"""
X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
q_weight, k_weight, v_weight = ctx.weights
q_quant, k_quant, v_quant = ctx.quants
q_scale, k_scale, v_scale = ctx.scales
dtype = X.dtype
# Reshape gradients
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
v_grad = v_grad.view(-1, v_grad.shape[-1])
X = X.view(-1, X.shape[-1])
# Pre-transpose X once
X_t = X.t()
# Initialize LoRA gradients as None
d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None
# Compute q path LoRA gradients if adapters exist
if A_q is not None and B_q is not None:
A_q_scaled = (q_scale * A_q).to(dtype)
B_q_scaled = B_q.to(dtype)
d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)
# Compute k path LoRA gradients if adapters exist
if A_k is not None and B_k is not None:
A_k_scaled = (k_scale * A_k).to(dtype)
B_k_scaled = B_k.to(dtype)
d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)
# Compute v path LoRA gradients if adapters exist
if A_v is not None and B_v is not None:
A_v_scaled = (v_scale * A_v).to(dtype)
B_v_scaled = B_v.to(dtype)
d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)
# Compute input gradient, reusing X memory if possible
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
if d_B_q is not None:
d_B_q = d_B_q.t()
if d_A_k is not None:
d_A_k = d_A_k.t()
if d_B_k is not None:
d_B_k = d_B_k.t()
if d_A_v is not None:
d_A_v = d_A_v.t()
if d_B_v is not None:
d_B_v = d_B_v.t()
return (
grad_X.view(batch, seq_len, -1),
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
d_A_v,
d_B_v,
None,
None,
)
def apply_lora_qkv(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query, Key, Value projections.
Args:
X: Input tensor
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
QW_quant,
QA,
QB,
QS,
KW,
KW_quant,
KA,
KB,
KS,
VW,
VW_quant,
VA,
VB,
VS,
inplace,
)
return Q, K, V
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
Args:
ctx: Autograd context
X: Input tensor
W: Output projection weight
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
S: LoRA scaling factor
Returns:
Output projection tensor
"""
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (
W,
W_quant,
S,
)
ctx.save_for_backward(A, B, X)
return XW
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
dY: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
]:
"""
Backward pass computing gradients for LoRA output projection.
Args:
ctx: Autograd context
dY: Gradient of loss with respect to output
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
dY = dY.reshape(-1, dY.shape[-1])
X = X.reshape(-1, X.shape[-1])
dtype = X.dtype
# Weight projection
dY_X = X.t() @ dY
d_A = S * dY_X @ B
d_B = S * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
"""
Applies LoRA to output projection layer.
Args:
X: Input tensor
Returns:
Transformed output tensor
"""
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
return output

View File

@@ -0,0 +1,149 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes
import bitsandbytes as bnb
import torch
from bitsandbytes.functional import QuantState, get_ptr
from packaging.version import Version
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
def dequantize(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Fast NF4 dequantization using `bitsandbytes` CUDA kernels.
Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
optimized CUDA implementations. Supports both legacy list and new `QuantState`
formats.
Args:
W: Quantized weight tensor to dequantize
quant_state: Quantization state containing metadata needed for
dequantization. Can be either a `QuantState` object or legacy list format.
If None, returns `W` unchanged.
out: Optional output tensor for storing dequantized results. Must match
expected shape and dtype if provided.
Returns:
Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
input `W` was transposed.
Raises:
AssertionError: If provided output tensor doesn't match expected shape / dtype.
Note:
Uses CUDA streams for better performance when available in newer `bitsandbytes`
versions (>0.43.3).
"""
if quant_state is None:
return W
# Get the target device from input tensor W
target_device = W.device
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device)
blocksize2 = state2.blocksize
else:
# Legacy list format
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
absmax = absmax.to(target_device)
offset, state2 = compressed_stats
offset = offset.to(target_device)
absmax2, code2, blocksize2, _, _, _, _ = state2
absmax2 = absmax2.to(target_device)
code2 = code2.to(target_device)
# Setup output tensor on the same device as input
if out is None:
out = torch.empty(shape, dtype=dtype, device=target_device)
else:
assert out.shape == shape and out.dtype == dtype
out = out.to(target_device)
# Dequantize statistics on the target device
n_elements_absmax: int = absmax.numel()
out_absmax: torch.Tensor = torch.empty(
n_elements_absmax, dtype=torch.float32, device=target_device
)
ptr_out_absmax: int = get_ptr(out_absmax)
# Use CUDA stream if available
if HAS_CUDA_STREAM:
global CUDA_STREAM
if CUDA_STREAM is None:
CUDA_STREAM = torch.cuda.current_stream(target_device)
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
CUDA_STREAM,
)
else:
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
)
out_absmax += offset
# Choose appropriate dequantization function
fx = (
cdequantize_blockwise_fp16_nf4
if dtype == torch.float16
else cdequantize_blockwise_bf16_nf4
)
# Dequantize weights
if HAS_CUDA_STREAM:
fx(
get_ptr(None),
get_ptr(W),
ptr_out_absmax,
get_ptr(out),
ctypes.c_int(blocksize),
ctypes.c_int(out.numel()),
CUDA_STREAM,
)
else:
fx(
get_ptr(None),
get_ptr(W),
ptr_out_absmax,
get_ptr(out),
ctypes.c_int(blocksize),
ctypes.c_int(out.numel()),
)
# Handle transposed data
is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out

View File

@@ -0,0 +1,163 @@
"""
Module for definition of SwiGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl
@triton.jit
def _swiglu_fwd_kernel(
gate_ptr,
up_ptr,
out_ptr,
n_elements,
block_size: tl.constexpr,
):
"""
SwiGLU forward kernel. The kernel computes activation in fp32 precision for better
numerical stability, then converts back to original dtype for the final result.
Args:
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
out_ptr: Pointer to output tensor `[*, hidden_dim]`.
n_elements: Total number of elements in the input tensors.
block_size: Size of thread blocks for parallel computation.
"""
block_idx = tl.program_id(0)
offsets = block_idx * block_size + tl.arange(0, block_size)
mask = offsets < n_elements
# Load gate in fp32, keep up in original dtype
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute activation in fp32 then convert back
f = gate * tl.sigmoid(gate)
f = f.to(up.dtype)
result = f * up
tl.store(out_ptr + offsets, result, mask=mask)
@triton.jit
def _swiglu_bwd_kernel(
grad_out_ptr,
gate_ptr,
up_ptr,
n_elements,
block_size: tl.constexpr,
):
"""
SwiGLU backward kernel. Stores gradient results in-place.
Args:
grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
n_elements: Total number of elements in the input tensors.
block_size: Size of thread blocks for parallel computation.
Note:
After kernel execution, tensors are modified in-place:
- `grad_out_ptr` contains forward output (`h`)
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
- `up_ptr` contains gradient w.r.t up (`grad_up`)
"""
block_idx = tl.program_id(0)
offsets = block_idx * block_size + tl.arange(0, block_size)
mask = offsets < n_elements
# Load values - only convert gate to fp32
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute SiLU and forward output
sigmoid_gate = tl.sigmoid(gate)
silu_gate = sigmoid_gate * gate
silu_gate = silu_gate.to(grad_out.dtype)
h = silu_gate * up
# Compute gradients
grad_up = grad_out * silu_gate # gradient for up is grad_out * SiLU(gate)
# Compute gate gradient
temp = grad_out * up
grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))
grad_gate = grad_gate.to(grad_out.dtype)
# Store results with correct gradient ordering
tl.store(grad_out_ptr + offsets, h, mask=mask)
tl.store(gate_ptr + offsets, grad_gate, mask=mask) # grad wrt gate
tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up
# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""
SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
`x` is the gate tensor.
Args:
gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.
Returns:
Output tensor of shape `[batch, seq_len, hidden_dim]`.
"""
batch, seq_len, hidden_dim = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
_swiglu_fwd_kernel[grid](
gate_ptr=gate,
up_ptr=up,
out_ptr=out,
n_elements=n_elements,
block_size=1024,
)
return out
# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
SwiGLU backward pass using in-place operations.
Args:
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
Returns:
Tuple containing:
- Forward pass output (`h`)
- Gradient with respect to gate (`df`)
- Gradient with respect to up-projection (`de`)
"""
n_elements = grad_output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
_swiglu_bwd_kernel[grid](
grad_out_ptr=grad_output,
gate_ptr=gate,
up_ptr=up,
n_elements=n_elements,
block_size=1024,
)
# After kernel execution, tensors contain:
# grad_output: h (forward output)
# gate: grad_gate (grad wrt gate)
# up: grad_up (grad wrt up)
return grad_output, gate, up

View File

@@ -0,0 +1,11 @@
"""Utilities for `axolotl.kernels` submodules."""
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")

View File

@@ -0,0 +1,333 @@
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""
import importlib
import inspect
import logging
import types
from typing import Type
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
APPLY_FN_MAPPING = {
"silu": apply_lora_mlp_swiglu,
"gelu": apply_lora_mlp_geglu,
}
def original_apply_qkv(
self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Original implementation of QKV projection without optimizations.
Args:
self: The attention module instance.
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim].
Returns:
A tuple `(query_states, key_states, value_states)` containing the projected
states for query, key, and value.
"""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Original implementation of output projection without optimizations.
Args:
self: The attention module instance.
hidden_states: Input tensor of shape `[`batch_size, seq_len, hidden_dim]`.
Returns:
The output projection result.
"""
attn_output = self.o_proj(hidden_states)
return attn_output
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
"""
Get the appropriate attention class by inspecting the model config.
Uses dynamic import to support any model architecture that follows
the standard transformers naming convention.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
The appropriate attention class for the model.
Raises:
ValueError: If `base_model` not specified or attention class cannot be imported
ImportError: If the model module or attention class doesn't exist
"""
if "base_model" not in cfg:
raise ValueError("base_model must be specified in config")
# Get model config without loading the model
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
# pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault):
"""
Given an `axolotl` config, this method patches the inferred attention class forward
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Raises:
AssertionError: If the required code blocks are not found in the attention
implementation.
"""
attention_cls = get_attention_cls_from_config(cfg)
# Check if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
return
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Load necessary imports
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
"""
Applies optimized Triton kernel patches to a PEFT model.
Patches a PEFT model with optimized implementations for MLP and attention
computations. The optimizations include custom Triton kernels for activation
functions and specialized autograd functions for LoRA computations.
Args:
model: A PEFT model to be patched with optimized kernels.
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
PeftModelForCausalLM: The patched model with optimized kernels.
Raises:
TypeError: If the provided model is not a `PeftModelForCausalLM`.
NotImplementedError: If the model type is not supported.
AssertionError: If multiple adapters are active (currently unsupported).
Note:
The optimizations require LoRA adapters with no dropout and no bias terms. The
function will skip patching if these conditions aren't met.
"""
if not isinstance(model, PeftModelForCausalLM):
raise TypeError("Model must be a PeftModelForCausalLM")
# Get active LoRA adapter config
if hasattr(model, "active_adapters"):
assert (
len(model.active_adapters) == 1
), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
active_adapter = model.active_adapters[0]
else:
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
LOG.setLevel(logging.INFO)
# Choose activation based on model type
activation = model.config.hidden_act
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
# Patch each layer
for layer in model.model.model.layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(
original_apply_qkv, layer.self_attn
)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
if cfg.lora_mlp_kernel:
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
layer.self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)
return model

View File

@@ -175,6 +175,7 @@ def train(
LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -185,6 +186,7 @@ def train(
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

View File

@@ -1,7 +1,4 @@
"""
Module for pydantic models for configuration
"""
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging
@@ -810,6 +807,10 @@ class AxolotlInputConfig(
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None
lora_mlp_kernel: Optional[bool] = None
lora_qkv_kernel: Optional[bool] = None
lora_o_kernel: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
@@ -1534,12 +1535,42 @@ class AxolotlInputConfig(
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
is_lora_kernel = any(
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
is_unsloth_lora = any(
data.get(k)
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
)
if is_lora_kernel and is_unsloth_lora:
raise ValueError(
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
@@ -1672,6 +1703,29 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
is_deepspeed = data.get("deepspeed") is not None
if capabilities and capabilities.get("n_gpu", 0) > 1:
if is_fsdp:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
)
if is_deepspeed:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):

View File

@@ -414,6 +414,7 @@ class ModelLoader:
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
@@ -425,10 +426,6 @@ class ModelLoader:
if self.cfg.is_llama_derived_model:
self.patch_loss_llama()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif self.cfg.is_llama_derived_model:
self.patch_llama_derived_model()
@@ -442,6 +439,11 @@ class ModelLoader:
patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -472,9 +474,7 @@ class ModelLoader:
return importlib.util.find_spec("flash_attn") is not None
def patch_loss_llama(self) -> None:
"""
Patch loss functions
"""
"""Patch loss functions and other optimizations"""
if self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
@@ -494,15 +494,14 @@ class ModelLoader:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def patch_llama_derived_model(self) -> None:
"""
Modify all llama derived models in one block
"""
"""Modify all llama derived models in one block"""
self.patch_loss_llama()
if self.cfg.flash_attention:
@@ -1013,7 +1012,8 @@ class ModelLoader:
if hasattr(module, "weight"):
module.to(dist_dtype)
def apply_lora_patch(self) -> None:
# TODO: Deprecate this.
def apply_unsloth_lora_patch(self) -> None:
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
@@ -1027,6 +1027,16 @@ class ModelLoader:
integrate_rope_embeddings()
def apply_lora_patch(self) -> None:
if (
self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel
or self.cfg.lora_o_kernel
):
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(self.model, self.cfg)
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
self.apply_patches()
self.set_auto_model_loader()
@@ -1171,6 +1181,7 @@ class ModelLoader:
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
self.apply_unsloth_lora_patch()
self.apply_lora_patch()
for _ in range(3):

View File

@@ -0,0 +1,76 @@
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.geglu import geglu_backward, geglu_forward
def test_geglu_forward_shape():
"""Test that GEGLU forward pass preserves expected shapes."""
batch, seq_len, hidden_dim = 2, 3, 64
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
out = geglu_forward(gate, up)
assert out.shape == (batch, seq_len, hidden_dim)
assert out.dtype == gate.dtype
assert out.device == gate.device
def test_geglu_forward_values():
"""Test GEGLU forward pass matches PyTorch reference implementation."""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
# Custom implementation
triton_out = geglu_forward(gate.clone(), up.clone())
# PyTorch reference
torch_out = F.gelu(gate) * up
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_geglu_backward():
"""Test GEGLU backward pass matches PyTorch autograd."""
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")
# PyTorch reference - compute intermediates
gelu_gate = F.gelu(gate)
torch_out = gelu_gate * up
torch_out.backward(grad_output)
# Custom backward pass
gate_clone = gate.clone().detach()
up_clone = up.clone().detach()
grad_output_clone = grad_output.clone()
h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)
# Compare outputs and gradients
assert torch.allclose(h, torch_out, rtol=1e-3)
assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
assert torch.allclose(grad_up, up.grad, rtol=1e-3)
def test_geglu_inplace_preservation():
"""Test that GEGLU backward doesn't modify original tensors unexpectedly."""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
grad_output = torch.randn(2, 3, 64, device="cuda")
gate_copy = gate.clone()
up_copy = up.clone()
grad_copy = grad_output.clone()
geglu_backward(grad_output, gate, up)
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"

View File

@@ -0,0 +1,531 @@
"""Tests for LoRA custom autograd."""
# pylint: disable=invalid-name,redefined-outer-name
import pytest
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from axolotl.kernels.geglu import geglu_backward, geglu_forward
from axolotl.kernels.lora import (
LoRA_MLP,
LoRA_O,
LoRA_QKV,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
get_lora_parameters,
matmul_lora,
)
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
@pytest.fixture
def mock_quantstate():
"""Creates a mock QuantState for testing"""
shape = (64, 64)
n_blocks = shape[0] # Assuming blockwise quantization along first dimension
# Create nested state first
nested_state = QuantState(
absmax=torch.ones(n_blocks, device="cuda"), # One value per block
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"), # NF4 range is 0-15
dtype=torch.float16,
blocksize=64,
quant_type="nf4",
offset=None,
state2=None,
)
# Create main state with nested state
return QuantState(
absmax=torch.ones(n_blocks, device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=64,
quant_type="nf4",
offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"),
state2=nested_state,
)
@pytest.fixture
def sample_tensors():
"""Creates sample tensors for testing"""
torch.manual_seed(42)
batch_size, seq_len, hidden_dim = 2, 3, 64
rank = 8
out_dim = hidden_dim
return {
"X": torch.randn(
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
"seq": seq_len,
"hidden": hidden_dim,
"out": out_dim,
"rank": rank,
},
}
@pytest.fixture
def mock_proj():
"""Creates a mock projection module for testing."""
class MockProj(nn.Module):
"""Mock projection class."""
def __init__(self, in_features=64, out_features=128, rank=8):
super().__init__()
self.base_layer = nn.Linear(in_features, out_features)
self.base_layer.to("cuda")
self.lora_A = nn.ModuleDict(
{"default": nn.Linear(in_features, rank, bias=False).to("cuda")}
)
self.lora_B = nn.ModuleDict(
{"default": nn.Linear(rank, out_features, bias=False).to("cuda")}
)
self.scaling = {"default": 0.5}
self.active_adapter = "default"
self.disable_adapters = False
self.merged = False
return MockProj()
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@pytest.mark.parametrize(
"activation_forward,activation_backward",
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):
"""Tests LoRA_MLP directly with different activation functions"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
# Create linear layers
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
# Test SwiGLU path
X.requires_grad = True
output = LoRA_MLP.apply(
X,
gate_proj.weight,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
None, # down_quant
None, # down_A
None, # down_B
None, # down_scale
activation_forward,
activation_backward,
True, # inplace
)
assert output.shape == X.shape
assert not torch.isnan(output).any()
# Test backward pass
loss = output.sum()
loss.backward()
assert X.grad is not None
assert not torch.isnan(X.grad).any()
@pytest.mark.parametrize(
"activation_forward,activation_backward",
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_with_adapters(
sample_tensors, activation_forward, activation_backward
):
"""Tests LoRA_MLP with LoRA adapters"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
# Create LoRA components
gate_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
gate_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
up_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
up_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
down_A = torch.randn(rank, out_dim, device="cuda", dtype=torch.float16)
down_B = torch.randn(hidden_dim, rank, device="cuda", dtype=torch.float16)
scale = 0.5
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
X.requires_grad = True
gate_A.requires_grad = True
gate_B.requires_grad = True
up_A.requires_grad = True
up_B.requires_grad = True
down_A.requires_grad = True
down_B.requires_grad = True
# Forward pass with adapters
output = LoRA_MLP.apply(
X,
gate_proj.weight,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
None,
up_A,
up_B,
scale,
down_proj.weight,
None,
down_A,
down_B,
scale,
activation_forward,
activation_backward,
True,
)
assert output.shape == X.shape
assert not torch.isnan(output).any()
# Test backward pass
loss = output.sum()
loss.backward()
# Check all gradients
assert X.grad is not None
assert gate_A.grad is not None
assert gate_B.grad is not None
assert up_A.grad is not None
assert up_B.grad is not None
assert down_A.grad is not None
assert down_B.grad is not None
assert not torch.isnan(X.grad).any()
assert not torch.isnan(gate_A.grad).any()
assert not torch.isnan(gate_B.grad).any()
assert not torch.isnan(up_A.grad).any()
assert not torch.isnan(up_B.grad).any()
assert not torch.isnan(down_A.grad).any()
assert not torch.isnan(down_B.grad).any()
def test_lora_qkv(sample_tensors):
"""Tests LoRA QKV implementation with and without adapters"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
rank = shapes["rank"]
# Create base weights
q_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
k_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
v_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
# Create LoRA matrices
q_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
q_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
k_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
k_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
v_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
v_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
scale = 0.5
X.requires_grad = True
# Test without LoRA adapters
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
True,
)
assert Q1.shape == K1.shape == V1.shape == X.shape
loss1 = (Q1 + K1 + V1).sum()
loss1.backward()
assert X.grad is not None
# Clear gradients
X.grad = None
# Test with LoRA adapters
Q2, K2, V2 = LoRA_QKV.apply(
X,
q_weight,
None,
q_A,
q_B,
scale,
k_weight,
None,
k_A,
k_B,
scale,
v_weight,
None,
v_A,
v_B,
scale,
True,
)
assert Q2.shape == K2.shape == V2.shape == X.shape
loss2 = (Q2 + K2 + V2).sum()
loss2.backward()
# Check gradients
assert X.grad is not None
assert q_A.grad is not None
assert q_B.grad is not None
assert k_A.grad is not None
assert k_B.grad is not None
assert v_A.grad is not None
assert v_B.grad is not None
# Check for NaN values
assert not torch.isnan(X.grad).any()
assert not torch.isnan(q_A.grad).any()
assert not torch.isnan(q_B.grad).any()
assert not torch.isnan(k_A.grad).any()
assert not torch.isnan(k_B.grad).any()
assert not torch.isnan(v_A.grad).any()
assert not torch.isnan(v_B.grad).any()
def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
# Test backward pass
loss = output.sum()
loss.backward()
assert X.grad is not None
def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
scale = 0.5
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@pytest.mark.parametrize(
"batch,seq,hidden,rank,out",
[
(1, 1, 32, 4, 64),
(2, 3, 64, 8, 128),
(4, 5, 128, 16, 256),
],
)
def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, None, A, B, scale)
assert result.shape == (batch, seq, out)
def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
X.requires_grad = True
A.requires_grad = True
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, None, A, B, scale)
loss = out.sum()
# Backward pass
loss.backward()
assert X.grad is not None
assert A.grad is not None
assert B.grad is not None
assert not torch.isnan(X.grad).any()
assert not torch.isnan(A.grad).any()
assert not torch.isnan(B.grad).any()
@pytest.mark.parametrize(
"apply_function",
[apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
)
def test_inplace_operations(sample_tensors, apply_function):
"""Tests inplace operation behavior"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
# Create MLP with both inplace=True and inplace=False
mlp = type(
"MLPModule",
(),
{
"gate_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
device="cuda", dtype=torch.float16
),
"up_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
device="cuda", dtype=torch.float16
),
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
device="cuda", dtype=torch.float16
),
},
)
out1 = apply_function(mlp, X.clone(), inplace=True)
out2 = apply_function(mlp, X.clone(), inplace=False)
assert torch.allclose(out1, out2, rtol=1e-3)

View File

@@ -0,0 +1,103 @@
"""Tests for quantization utility functions."""
# pylint: disable=invalid-name
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
def test_dequantize_null_state():
"""Test that dequantize returns input unchanged when quant_state is None"""
W = torch.randn(32, 32)
assert torch.equal(dequantize(W, None), W)
def test_dequantize_shape_preservation():
"""Test that dequantization preserves expected shapes"""
shape = (32, 32)
W = torch.randn(shape, device="cuda")
quant_state = QuantState(
absmax=torch.ones(shape[0], device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
state2=QuantState(
absmax=torch.ones(shape[0], device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state)
assert result.shape == shape
assert result.dtype == torch.float16
assert result.device == W.device
def test_dequantize_transposed():
"""Test that transposed input produces transposed output"""
shape = (32, 32)
W = torch.randn(1, shape[1], device="cuda") # Transposed input
quant_state = QuantState(
absmax=torch.ones(1),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(1, dtype=torch.int32),
state2=QuantState(
absmax=torch.ones(1),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state)
assert result.shape[0] == shape[0]
def test_dequantize_output_tensor():
"""Test dequantization with provided output tensor"""
shape = (32, 32)
W = torch.randn(shape, device="cuda")
out = torch.empty(shape, dtype=torch.float16, device="cuda")
quant_state = QuantState(
absmax=torch.ones(shape[0]),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(shape[0], dtype=torch.int32),
state2=QuantState(
absmax=torch.ones(shape[0]),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state, out=out)
assert result is out

View File

@@ -0,0 +1,78 @@
"""Tests for SwiGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
def test_swiglu_forward_shape():
"""Test that SwiGLU forward pass preserves expected shapes"""
batch, seq_len, hidden_dim = 2, 3, 64
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
out = swiglu_forward(gate, up)
assert out.shape == (batch, seq_len, hidden_dim)
assert out.dtype == gate.dtype
assert out.device == gate.device
def test_swiglu_forward_values():
"""Test SwiGLU forward pass matches PyTorch reference implementation"""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
# Custom implementation
triton_out = swiglu_forward(gate.clone(), up.clone())
# PyTorch reference
torch_out = F.silu(gate) * up
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_swiglu_backward():
"""Test SwiGLU backward pass matches PyTorch autograd"""
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")
# PyTorch reference - compute intermediates
silu_gate = F.silu(gate)
torch_out = silu_gate * up
torch_out.backward(grad_output)
# Custom backward pass
gate_clone = gate.clone().detach()
up_clone = up.clone().detach()
grad_output_clone = grad_output.clone()
h, our_grad_gate, our_grad_up = swiglu_backward(
grad_output_clone, gate_clone, up_clone
)
# Compare outputs and gradients
assert torch.allclose(h, torch_out, rtol=1e-3)
assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)
assert torch.allclose(our_grad_up, up.grad, rtol=1e-3)
def test_swiglu_inplace_preservation():
"""Test that SwiGLU backward doesn't modify original tensors unexpectedly"""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
grad_output = torch.randn(2, 3, 64, device="cuda")
gate_copy = gate.clone()
up_copy = up.clone()
grad_copy = grad_output.clone()
swiglu_backward(grad_output, gate, up)
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"

View File

@@ -0,0 +1,414 @@
"""Integration tests for LoRA activation and attention kernels."""
# pylint: disable=redefined-outer-name
import pytest
import torch
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "HuggingFaceTB/SmolLM2-135M",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
]
@pytest.fixture(autouse=True)
def init_accelerate():
"""Initialize Accelerate state before tests."""
_ = PartialState()
@pytest.fixture
def small_llama_model():
"""Create a small LLaMA model for testing."""
config = {
"vocab_size": 100,
"hidden_size": 128,
"intermediate_size": 256,
"num_hidden_layers": 2,
"num_attention_heads": 4,
}
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
# Check the forward method was replaced
assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored
assert hasattr(LlamaAttention, "_original_forward")
# Clean up
setattr(LlamaAttention, "forward", original_forward)
delattr(LlamaAttention, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model):
"""Test SwiGLU activation in LoRA MLP context."""
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
cfg = DictDefault({"lora_mlp_kernel": True})
# Apply patches
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify patches
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
# Test forward pass
batch_size, seq_len = 2, 10
hidden_states = torch.randn(
batch_size, seq_len, model.config.hidden_size, device=model.device
)
position_ids = (
torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1)
)
cos, sin = model.model.model.rotary_emb(hidden_states, position_ids)
inputs = {
"hidden_states": hidden_states,
"attention_mask": None,
"position_embeddings": (cos, sin),
"output_attentions": False,
"use_cache": False,
"past_key_value": None,
}
# Compare outputs
with torch.no_grad():
original_output = model.model.model.layers[0](**inputs)[0]
patched_output = layer(**inputs)[0]
assert torch.allclose(original_output, patched_output, rtol=1e-4)
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
)
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify patches
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_geglu
# Test end-to-end
inputs = torch.randint(0, 100, (1, 20), device=model.device, dtype=torch.long)
with torch.no_grad():
original_output = model(inputs).logits
patched_output = patched_model(inputs).logits
assert torch.allclose(original_output, patched_output, rtol=1e-4)
@pytest.mark.parametrize(
"model_name,expected_activation",
[
("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu),
("mhenrichsen/gemma-2b", apply_lora_mlp_geglu),
],
)
def test_model_specific_activation(model_name, expected_activation):
"""Test that each model type gets the correct activation function."""
model = AutoModelForCausalLM.from_pretrained(model_name)
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is expected_activation
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
},
]
for config in test_configs:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_config_options():
"""Test that kernel configuration options are respected."""
# Test different configurations
test_configs = [
(
{"lora_mlp_kernel": True, "lora_qkv_kernel": False, "lora_o_kernel": False},
lambda layer: (
layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
and layer.self_attn.apply_o.__func__ is not apply_lora_o
),
),
(
{"lora_mlp_kernel": False, "lora_qkv_kernel": True, "lora_o_kernel": False},
lambda layer: (
layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is apply_lora_qkv
and layer.self_attn.apply_o.__func__ is not apply_lora_o
),
),
(
{"lora_mlp_kernel": False, "lora_qkv_kernel": False, "lora_o_kernel": True},
lambda layer: (
layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
and layer.self_attn.apply_o.__func__ is apply_lora_o
),
),
]
for config_dict, check_fn in test_configs:
# Create fresh model for each test
config = {
"vocab_size": 100,
"hidden_size": 128,
"intermediate_size": 256,
"num_hidden_layers": 2,
"num_attention_heads": 4,
}
small_llama_model = LlamaForCausalLM(LlamaConfig(**config))
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": [
"gate_proj",
"up_proj",
"down_proj",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
cfg = DictDefault(config_dict)
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify only requested optimizations were applied
for layer in patched_model.model.model.layers:
assert check_fn(layer), f"Failed for config: {config_dict}"
# Clean up
del model
del small_llama_model
del patched_model
def get_lora_config():
"""Get standard LoRA configuration for testing."""
return {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
def get_test_inputs(model, seq_length=20):
"""Generate test inputs for model evaluation."""
return torch.randint(
0,
model.config.vocab_size,
(1, seq_length),
device=model.device,
dtype=torch.long,
)
@pytest.mark.parametrize("model_config", MODEL_CONFIGS)
def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
)
# Apply LoRA configuration
peft_config = get_peft_config(get_lora_config())
model = PeftModelForCausalLM(model, peft_config)
# Apply kernel patches
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify correct activation function
layer = patched_model.model.model.layers[0]
assert (
layer.mlp.forward.__func__ is model_config["expected_activation"]
), f"Wrong activation for {model_config['name']}"
# Test forward pass
inputs = get_test_inputs(model)
with torch.no_grad():
original_output = model(inputs).logits
patched_output = patched_model(inputs).logits
# Check outputs match
assert torch.allclose(
original_output, patched_output, rtol=1e-4
), f"Outputs don't match for {model_config['name']}"
# pylint: disable=duplicate-code
def test_kernel_training_integration():
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model
model, _ = load_model_and_tokenizer(cfg=cfg)
# Verify correct activation function
layer = model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu