Activation function Triton kernels, LoRA custom autograd functions (#2324)

* LoRA + activation fn Triton kernels: initial commit

* implementing optims

* finalizing MLP LoRA kernels and progress on QKV / W kernels

* updates

* O projection optim

* adding monkey patching logic

* doc strings, typing, pre-commit fixes

* updates

* adding lora 8b kernels example

* working on fsdp support

* tests and fixes

* small fixes, getting tests to pass, adding doc strings

* integration tests for LoRA patching

* config.qmd

* remove unneeded pytest fixture

* fix

* review comments first pass

* improving tests, attention class agnostic patching

* adding support for more archs

* wip SiLU / GELU impls

* improved testing, small updates, etc.

* slightly updating docs

* rebase

* fixing test_attention_patching_integration

* additional review comments, fixing test in CI (hopefully)

* isolating problematic patching test

* relaxing allclose threshold to reduce flakiness

* fixing accidental change

* adding model arch agnostic attention class fetching

* removing unused activations
This commit is contained in:
Dan Saunders
2025-02-17 14:23:15 -05:00
committed by GitHub
parent 97a2fa2781
commit 3d8425fa91
22 changed files with 3102 additions and 22 deletions

View File

@@ -167,7 +167,6 @@ def train(
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -201,6 +200,8 @@ def train(
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
@@ -229,6 +230,8 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)

View File

View File

@@ -0,0 +1,159 @@
"""
Module for definition of GEGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch
import triton
import triton.language as tl
SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π)
@triton.jit
def _geglu_fwd_kernel(
gate_ptr,
up_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""GEGLU forward kernel.
Args:
gate_ptr: Pointer to gate tensor [*, hidden_dim].
up_ptr: Pointer to up-projection tensor [*, hidden_dim].
out_ptr: Pointer to output tensor [*, hidden_dim].
n_elements: Total number of elements in the input tensors.
BLOCK_SIZE: Size of thread blocks for parallel computation.
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute activation in fp32 then convert back
gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
gelu_gate = gelu_gate.to(up.dtype)
result = gelu_gate * up
tl.store(out_ptr + offsets, result, mask=mask)
def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""GEGLU forward pass.
Args:
gate: Input gate tensor of shape [batch, seq_len, hidden_dim].
up: Up-projection tensor of shape [batch, seq_len, hidden_dim].
Returns:
torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim].
"""
batch, seq_len, hidden_dim = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
_geglu_fwd_kernel[grid](
gate_ptr=gate,
up_ptr=up,
out_ptr=out,
n_elements=n_elements,
BLOCK_SIZE=1024,
)
return out
@triton.jit
def _geglu_bwd_kernel(
grad_out_ptr,
gate_ptr,
up_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""GEGLU backward kernel. Stores gradient results in-place.
Args:
grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
gate_ptr: Pointer to gate tensor [*, hidden_dim].
up_ptr: Pointer to up-projection tensor [*, hidden_dim].
n_elements: Total number of elements in the input tensors.
BLOCK_SIZE: Size of thread blocks for parallel computation.
Note:
After kernel execution, tensors are modified in-place:
- `grad_out_ptr` contains GEGLU activation output (`h`)
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
- `up_ptr` contains gradient w.r.t up (`grad_up`)
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Forward pass
gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
gelu_gate = gelu_partial * gate
gelu_gate = gelu_gate.to(grad_out.dtype)
# Forward output
h = gelu_gate * up
# Compute gradients
grad_up = grad_out * gelu_gate
# Compute gate gradient using GELU derivative
temp = grad_out * up
t = 0.3989422804014327 # 1/sqrt(2*pi)
dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
grad_gate = temp.to(tl.float32) * dgelu_dgate
grad_gate = grad_gate.to(grad_out.dtype)
# Store results
tl.store(grad_out_ptr + offsets, h, mask=mask)
tl.store(gate_ptr + offsets, grad_gate, mask=mask)
tl.store(up_ptr + offsets, grad_up, mask=mask)
def geglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""GEGLU backward pass using in-place operations.
Args:
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
Returns:
Tuple containing:
- GEGLU activation output (`h`)
- Gradient with respect to gate (`grad_gate`)
- Gradient with respect to up (`grad_up`)
Note:
This function modifies its input tensors in-place to store results.
"""
n_elements = grad_output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
_geglu_bwd_kernel[grid](
grad_out_ptr=grad_output,
gate_ptr=gate,
up_ptr=up,
n_elements=n_elements,
BLOCK_SIZE=1024,
)
return grad_output, gate, up

779
src/axolotl/kernels/lora.py Normal file
View File

@@ -0,0 +1,779 @@
"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.
See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
]:
"""
Gets LoRA parameters from a projection module.
Args:
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
active_adapter = (
proj.active_adapters[0]
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
Args:
X: Input tensor [*, in_features]
W: Base weight matrix [out_features, in_features]
W_quant: Quantization state for W
A: LoRA A matrix [rank, in_features]
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
W = dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
class LoRA_MLP(torch.autograd.Function):
"""Optimized LoRA MLP implementation."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_quant: object | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_quant: object | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_quant: object | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
activation_fn: Callable,
activation_fn_backward: Callable,
inplace: bool | None = True,
) -> torch.Tensor:
"""
Forward pass for LoRA MLP.
Args:
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
Returns:
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_quant, down_A, down_B, down_scale
)
# Save for backward
ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
ctx.scales = (gate_scale, up_scale, down_scale)
ctx.quants = (gate_quant, up_quant, down_quant)
ctx.weights = (gate_weight, up_weight, down_weight)
ctx.activation_fn = activation_fn
ctx.activation_fn_backward = activation_fn_backward
ctx.inplace = inplace
return output
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_output: torch.Tensor,
) -> tuple[
torch.Tensor | None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
Args:
ctx: Context object storing tensors saved during forward pass
grad_output: Gradient of loss with respect to layer output
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
"""
(
X,
gate,
up,
gate_A,
gate_B,
up_A,
up_B,
down_A,
down_B,
) = ctx.saved_tensors
gate_scale, up_scale, down_scale = ctx.scales
gate_quant, up_quant, down_quant = ctx.quants
gate_weight, up_weight, down_weight = ctx.weights
# Transpose all LoRA matrices
gate_A, gate_B = (
gate_A.t() if gate_A is not None else None,
gate_B.t() if gate_B is not None else None,
)
up_A, up_B = (
up_A.t() if up_A is not None else None,
up_B.t() if up_B is not None else None,
)
down_A, down_B = (
down_A.t() if down_A is not None else None,
down_B.t() if down_B is not None else None,
)
# Reshape inputs
batch, seq_len, hd = X.shape
grad_output = grad_output.view(-1, grad_output.shape[-1])
X = X.view(-1, X.shape[-1])
gate = gate.view(-1, gate.shape[-1])
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection
DW = matmul_lora(
grad_output,
down_weight.t(),
down_quant,
down_B,
down_A,
down_scale,
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
if down_A is not None:
d_down_A = h.t() @ (grad_output @ down_B.t())
d_down_B = (down_A.t() @ h.t()) @ grad_output
d_down_A *= down_scale
d_down_B *= down_scale
if up_A is not None:
d_up_A = X.t() @ (grad_up @ up_B.t())
d_up_B = (up_A.t() @ X.t()) @ grad_up
d_up_A *= up_scale
d_up_B *= up_scale
if gate_A is not None:
d_gate_A = X.t() @ (grad_gate @ gate_B.t())
d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
d_gate_A *= gate_scale
d_gate_B *= gate_scale
# Compute input gradients
dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
if dX is not None:
# Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
dX = torch.matmul(grad_up, up_weight.t())
del up_weight
# Note the .to(dtype) only where mixing LoRA with base weights
if up_A is not None:
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
del gate_weight
if gate_A is not None:
dX += (
grad_gate
@ gate_B.to(dtype).t()
@ (gate_scale * gate_A.to(dtype).t())
)
# Reshape back
dX = dX.view(batch, seq_len, hd)
# Return gradients in correct order matching forward inputs
return (
dX,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
)
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with SwiGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
swiglu_forward,
swiglu_backward,
inplace,
)
return out
def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with GEGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
geglu_forward,
geglu_backward,
inplace,
)
return out
class LoRA_QKV(torch.autograd.Function):
"""
Optimized LoRA QKV implementation with quantization support.
Implements efficient computation of query, key, value projections with LoRA,
supporting quantization and memory optimization.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
v_scale: float,
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass computing Q, K, V projections with LoRA.
Args:
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
v_scale: Value LoRA scale
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.inplace = inplace
return Q, K, V
@staticmethod
@torch_amp_custom_fwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
v_grad: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
]:
"""
Backward pass computing gradients for LoRA QKV.
Args:
ctx: Autograd context
q_grad: Gradient for query projection
k_grad: Gradient for key projection
v_grad: Gradient for value projection
Returns:
Tuple containing gradients for all forward inputs
"""
X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
q_weight, k_weight, v_weight = ctx.weights
q_quant, k_quant, v_quant = ctx.quants
q_scale, k_scale, v_scale = ctx.scales
dtype = X.dtype
# Reshape gradients
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
v_grad = v_grad.view(-1, v_grad.shape[-1])
X = X.view(-1, X.shape[-1])
# Pre-transpose X once
X_t = X.t()
# Initialize LoRA gradients as None
d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None
# Compute q path LoRA gradients if adapters exist
if A_q is not None and B_q is not None:
A_q_scaled = (q_scale * A_q).to(dtype)
B_q_scaled = B_q.to(dtype)
d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)
# Compute k path LoRA gradients if adapters exist
if A_k is not None and B_k is not None:
A_k_scaled = (k_scale * A_k).to(dtype)
B_k_scaled = B_k.to(dtype)
d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)
# Compute v path LoRA gradients if adapters exist
if A_v is not None and B_v is not None:
A_v_scaled = (v_scale * A_v).to(dtype)
B_v_scaled = B_v.to(dtype)
d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)
# Compute input gradient, reusing X memory if possible
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
if d_B_q is not None:
d_B_q = d_B_q.t()
if d_A_k is not None:
d_A_k = d_A_k.t()
if d_B_k is not None:
d_B_k = d_B_k.t()
if d_A_v is not None:
d_A_v = d_A_v.t()
if d_B_v is not None:
d_B_v = d_B_v.t()
return (
grad_X.view(batch, seq_len, -1),
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
d_A_v,
d_B_v,
None,
None,
)
def apply_lora_qkv(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query, Key, Value projections.
Args:
X: Input tensor
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
QW_quant,
QA,
QB,
QS,
KW,
KW_quant,
KA,
KB,
KS,
VW,
VW_quant,
VA,
VB,
VS,
inplace,
)
return Q, K, V
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
Args:
ctx: Autograd context
X: Input tensor
W: Output projection weight
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
S: LoRA scaling factor
Returns:
Output projection tensor
"""
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (
W,
W_quant,
S,
)
ctx.save_for_backward(A, B, X)
return XW
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
dY: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
]:
"""
Backward pass computing gradients for LoRA output projection.
Args:
ctx: Autograd context
dY: Gradient of loss with respect to output
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
dY = dY.reshape(-1, dY.shape[-1])
X = X.reshape(-1, X.shape[-1])
dtype = X.dtype
# Weight projection
dY_X = X.t() @ dY
d_A = S * dY_X @ B
d_B = S * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
"""
Applies LoRA to output projection layer.
Args:
X: Input tensor
Returns:
Transformed output tensor
"""
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
return output

View File

@@ -0,0 +1,149 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes
import bitsandbytes as bnb
import torch
from bitsandbytes.functional import QuantState, get_ptr
from packaging.version import Version
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
def dequantize(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Fast NF4 dequantization using `bitsandbytes` CUDA kernels.
Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
optimized CUDA implementations. Supports both legacy list and new `QuantState`
formats.
Args:
W: Quantized weight tensor to dequantize
quant_state: Quantization state containing metadata needed for
dequantization. Can be either a `QuantState` object or legacy list format.
If None, returns `W` unchanged.
out: Optional output tensor for storing dequantized results. Must match
expected shape and dtype if provided.
Returns:
Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
input `W` was transposed.
Raises:
AssertionError: If provided output tensor doesn't match expected shape / dtype.
Note:
Uses CUDA streams for better performance when available in newer `bitsandbytes`
versions (>0.43.3).
"""
if quant_state is None:
return W
# Get the target device from input tensor W
target_device = W.device
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device)
blocksize2 = state2.blocksize
else:
# Legacy list format
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
absmax = absmax.to(target_device)
offset, state2 = compressed_stats
offset = offset.to(target_device)
absmax2, code2, blocksize2, _, _, _, _ = state2
absmax2 = absmax2.to(target_device)
code2 = code2.to(target_device)
# Setup output tensor on the same device as input
if out is None:
out = torch.empty(shape, dtype=dtype, device=target_device)
else:
assert out.shape == shape and out.dtype == dtype
out = out.to(target_device)
# Dequantize statistics on the target device
n_elements_absmax: int = absmax.numel()
out_absmax: torch.Tensor = torch.empty(
n_elements_absmax, dtype=torch.float32, device=target_device
)
ptr_out_absmax: int = get_ptr(out_absmax)
# Use CUDA stream if available
if HAS_CUDA_STREAM:
global CUDA_STREAM
if CUDA_STREAM is None:
CUDA_STREAM = torch.cuda.current_stream(target_device)
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
CUDA_STREAM,
)
else:
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
)
out_absmax += offset
# Choose appropriate dequantization function
fx = (
cdequantize_blockwise_fp16_nf4
if dtype == torch.float16
else cdequantize_blockwise_bf16_nf4
)
# Dequantize weights
if HAS_CUDA_STREAM:
fx(
get_ptr(None),
get_ptr(W),
ptr_out_absmax,
get_ptr(out),
ctypes.c_int(blocksize),
ctypes.c_int(out.numel()),
CUDA_STREAM,
)
else:
fx(
get_ptr(None),
get_ptr(W),
ptr_out_absmax,
get_ptr(out),
ctypes.c_int(blocksize),
ctypes.c_int(out.numel()),
)
# Handle transposed data
is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out

View File

@@ -0,0 +1,163 @@
"""
Module for definition of SwiGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl
@triton.jit
def _swiglu_fwd_kernel(
gate_ptr,
up_ptr,
out_ptr,
n_elements,
block_size: tl.constexpr,
):
"""
SwiGLU forward kernel. The kernel computes activation in fp32 precision for better
numerical stability, then converts back to original dtype for the final result.
Args:
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
out_ptr: Pointer to output tensor `[*, hidden_dim]`.
n_elements: Total number of elements in the input tensors.
block_size: Size of thread blocks for parallel computation.
"""
block_idx = tl.program_id(0)
offsets = block_idx * block_size + tl.arange(0, block_size)
mask = offsets < n_elements
# Load gate in fp32, keep up in original dtype
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute activation in fp32 then convert back
f = gate * tl.sigmoid(gate)
f = f.to(up.dtype)
result = f * up
tl.store(out_ptr + offsets, result, mask=mask)
@triton.jit
def _swiglu_bwd_kernel(
grad_out_ptr,
gate_ptr,
up_ptr,
n_elements,
block_size: tl.constexpr,
):
"""
SwiGLU backward kernel. Stores gradient results in-place.
Args:
grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
n_elements: Total number of elements in the input tensors.
block_size: Size of thread blocks for parallel computation.
Note:
After kernel execution, tensors are modified in-place:
- `grad_out_ptr` contains forward output (`h`)
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
- `up_ptr` contains gradient w.r.t up (`grad_up`)
"""
block_idx = tl.program_id(0)
offsets = block_idx * block_size + tl.arange(0, block_size)
mask = offsets < n_elements
# Load values - only convert gate to fp32
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute SiLU and forward output
sigmoid_gate = tl.sigmoid(gate)
silu_gate = sigmoid_gate * gate
silu_gate = silu_gate.to(grad_out.dtype)
h = silu_gate * up
# Compute gradients
grad_up = grad_out * silu_gate # gradient for up is grad_out * SiLU(gate)
# Compute gate gradient
temp = grad_out * up
grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))
grad_gate = grad_gate.to(grad_out.dtype)
# Store results with correct gradient ordering
tl.store(grad_out_ptr + offsets, h, mask=mask)
tl.store(gate_ptr + offsets, grad_gate, mask=mask) # grad wrt gate
tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up
# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""
SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
`x` is the gate tensor.
Args:
gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.
Returns:
Output tensor of shape `[batch, seq_len, hidden_dim]`.
"""
batch, seq_len, hidden_dim = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
_swiglu_fwd_kernel[grid](
gate_ptr=gate,
up_ptr=up,
out_ptr=out,
n_elements=n_elements,
block_size=1024,
)
return out
# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
SwiGLU backward pass using in-place operations.
Args:
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
Returns:
Tuple containing:
- Forward pass output (`h`)
- Gradient with respect to gate (`df`)
- Gradient with respect to up-projection (`de`)
"""
n_elements = grad_output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
_swiglu_bwd_kernel[grid](
grad_out_ptr=grad_output,
gate_ptr=gate,
up_ptr=up,
n_elements=n_elements,
block_size=1024,
)
# After kernel execution, tensors contain:
# grad_output: h (forward output)
# gate: grad_gate (grad wrt gate)
# up: grad_up (grad wrt up)
return grad_output, gate, up

View File

@@ -0,0 +1,11 @@
"""Utilities for `axolotl.kernels` submodules."""
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")

View File

@@ -0,0 +1,333 @@
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""
import importlib
import inspect
import logging
import types
from typing import Type
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
APPLY_FN_MAPPING = {
"silu": apply_lora_mlp_swiglu,
"gelu": apply_lora_mlp_geglu,
}
def original_apply_qkv(
self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Original implementation of QKV projection without optimizations.
Args:
self: The attention module instance.
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim].
Returns:
A tuple `(query_states, key_states, value_states)` containing the projected
states for query, key, and value.
"""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Original implementation of output projection without optimizations.
Args:
self: The attention module instance.
hidden_states: Input tensor of shape `[`batch_size, seq_len, hidden_dim]`.
Returns:
The output projection result.
"""
attn_output = self.o_proj(hidden_states)
return attn_output
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
"""
Get the appropriate attention class by inspecting the model config.
Uses dynamic import to support any model architecture that follows
the standard transformers naming convention.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
The appropriate attention class for the model.
Raises:
ValueError: If `base_model` not specified or attention class cannot be imported
ImportError: If the model module or attention class doesn't exist
"""
if "base_model" not in cfg:
raise ValueError("base_model must be specified in config")
# Get model config without loading the model
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
# pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault):
"""
Given an `axolotl` config, this method patches the inferred attention class forward
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Raises:
AssertionError: If the required code blocks are not found in the attention
implementation.
"""
attention_cls = get_attention_cls_from_config(cfg)
# Check if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
return
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Load necessary imports
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
"""
Applies optimized Triton kernel patches to a PEFT model.
Patches a PEFT model with optimized implementations for MLP and attention
computations. The optimizations include custom Triton kernels for activation
functions and specialized autograd functions for LoRA computations.
Args:
model: A PEFT model to be patched with optimized kernels.
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
PeftModelForCausalLM: The patched model with optimized kernels.
Raises:
TypeError: If the provided model is not a `PeftModelForCausalLM`.
NotImplementedError: If the model type is not supported.
AssertionError: If multiple adapters are active (currently unsupported).
Note:
The optimizations require LoRA adapters with no dropout and no bias terms. The
function will skip patching if these conditions aren't met.
"""
if not isinstance(model, PeftModelForCausalLM):
raise TypeError("Model must be a PeftModelForCausalLM")
# Get active LoRA adapter config
if hasattr(model, "active_adapters"):
assert (
len(model.active_adapters) == 1
), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
active_adapter = model.active_adapters[0]
else:
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
LOG.setLevel(logging.INFO)
# Choose activation based on model type
activation = model.config.hidden_act
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
# Patch each layer
for layer in model.model.model.layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(
original_apply_qkv, layer.self_attn
)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
if cfg.lora_mlp_kernel:
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
layer.self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)
return model

View File

@@ -175,6 +175,7 @@ def train(
LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -185,6 +186,7 @@ def train(
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

View File

@@ -1,7 +1,4 @@
"""
Module for pydantic models for configuration
"""
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging
@@ -810,6 +807,10 @@ class AxolotlInputConfig(
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None
lora_mlp_kernel: Optional[bool] = None
lora_qkv_kernel: Optional[bool] = None
lora_o_kernel: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
@@ -1534,12 +1535,42 @@ class AxolotlInputConfig(
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
is_lora_kernel = any(
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
is_unsloth_lora = any(
data.get(k)
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
)
if is_lora_kernel and is_unsloth_lora:
raise ValueError(
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
@@ -1672,6 +1703,29 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
is_deepspeed = data.get("deepspeed") is not None
if capabilities and capabilities.get("n_gpu", 0) > 1:
if is_fsdp:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
)
if is_deepspeed:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):

View File

@@ -414,6 +414,7 @@ class ModelLoader:
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
@@ -425,10 +426,6 @@ class ModelLoader:
if self.cfg.is_llama_derived_model:
self.patch_loss_llama()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif self.cfg.is_llama_derived_model:
self.patch_llama_derived_model()
@@ -442,6 +439,11 @@ class ModelLoader:
patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -472,9 +474,7 @@ class ModelLoader:
return importlib.util.find_spec("flash_attn") is not None
def patch_loss_llama(self) -> None:
"""
Patch loss functions
"""
"""Patch loss functions and other optimizations"""
if self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
@@ -494,15 +494,14 @@ class ModelLoader:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def patch_llama_derived_model(self) -> None:
"""
Modify all llama derived models in one block
"""
"""Modify all llama derived models in one block"""
self.patch_loss_llama()
if self.cfg.flash_attention:
@@ -1013,7 +1012,8 @@ class ModelLoader:
if hasattr(module, "weight"):
module.to(dist_dtype)
def apply_lora_patch(self) -> None:
# TODO: Deprecate this.
def apply_unsloth_lora_patch(self) -> None:
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
@@ -1027,6 +1027,16 @@ class ModelLoader:
integrate_rope_embeddings()
def apply_lora_patch(self) -> None:
if (
self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel
or self.cfg.lora_o_kernel
):
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(self.model, self.cfg)
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
self.apply_patches()
self.set_auto_model_loader()
@@ -1171,6 +1181,7 @@ class ModelLoader:
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
self.apply_unsloth_lora_patch()
self.apply_lora_patch()
for _ in range(3):