Activation function Triton kernels, LoRA custom autograd functions (#2324)
* LoRA + activation fn Triton kernels: initial commit * implementing optims * finalizing MLP LoRA kernels and progress on QKV / W kernels * updates * O projection optim * adding monkey patching logic * doc strings, typing, pre-commit fixes * updates * adding lora 8b kernels example * working on fsdp support * tests and fixes * small fixes, getting tests to pass, adding doc strings * integration tests for LoRA patching * config.qmd * remove unneeded pytest fixture * fix * review comments first pass * improving tests, attention class agnostic patching * adding support for more archs * wip SiLU / GELU impls * improved testing, small updates, etc. * slightly updating docs * rebase * fixing test_attention_patching_integration * additional review comments, fixing test in CI (hopefully) * isolating problematic patching test * relaxing allclose threshold to reduce flakiness * fixing accidental change * adding model arch agnostic attention class fetching * removing unused activations
This commit is contained in:
@@ -167,7 +167,6 @@ def train(
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
@@ -201,6 +200,8 @@ def train(
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
@@ -229,6 +230,8 @@ def train(
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
|
||||
0
src/axolotl/kernels/__init__.py
Normal file
0
src/axolotl/kernels/__init__.py
Normal file
159
src/axolotl/kernels/geglu.py
Normal file
159
src/axolotl/kernels/geglu.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Module for definition of GEGLU Triton kernels.
|
||||
|
||||
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π)
|
||||
|
||||
|
||||
@triton.jit
def _geglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU forward kernel.

    Computes `out = GELU(gate) * up` element-wise, using the exact erf-based GELU
    `0.5 * x * (1 + erf(x / sqrt(2)))`. All three buffers are addressed through the
    same flat offsets, so they are treated as 1-D arrays of `n_elements` items
    (presumably contiguous tensors of identical shape; the launcher must ensure
    this).

    Args:
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        out_ptr: Pointer to output tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.
    """
    # One program instance processes one BLOCK_SIZE-wide slice of the flat index space
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last block may extend past the end of the data; mask those lanes out
    mask = offsets < n_elements

    # Gate is upcast to fp32 for the activation; `up` stays in its storage dtype
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute activation in fp32 then convert back
    gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_gate.to(up.dtype)
    result = gelu_gate * up

    tl.store(out_ptr + offsets, result, mask=mask)
|
||||
|
||||
|
||||
def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """GEGLU forward pass.

    Launches the element-wise Triton kernel computing `GELU(gate) * up`. The kernel
    indexes both inputs with flat offsets, so `gate` and `up` are expected to be
    contiguous and to hold the same number of elements.

    Args:
        gate: Input gate tensor, typically of shape [batch, seq_len, hidden_dim].
        up: Up-projection tensor with the same shape as `gate`.

    Returns:
        torch.Tensor: Newly allocated output tensor with the shape, dtype and device
            of `gate`.
    """
    n_elements = gate.numel()
    # Allocate next to the inputs: a hard-coded "cuda" would select the *current*
    # default device, which is wrong when `gate` lives on another GPU.
    out = torch.empty(gate.shape, dtype=gate.dtype, device=gate.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )
    return out
|
||||
|
||||
|
||||
@triton.jit
def _geglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU backward kernel. Stores gradient results in-place.

    Recomputes the forward activation from `gate` / `up` and derives both input
    gradients in a single pass, then overwrites the three input buffers with the
    results so no extra output buffers are needed.

    Args:
        grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.

    Note:
        After kernel execution, tensors are modified in-place:
        - `grad_out_ptr` contains GEGLU activation output (`h`)
        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
        - `up_ptr` contains gradient w.r.t up (`grad_up`)
    """
    # One program instance processes one BLOCK_SIZE-wide slice of the flat index space
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out lanes past the end of the data in the final block
    mask = offsets < n_elements

    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Forward pass
    # gelu_partial is the standard normal CDF Phi(gate); GELU(gate) = gate * Phi(gate)
    gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_partial * gate
    gelu_gate = gelu_gate.to(grad_out.dtype)

    # Forward output
    h = gelu_gate * up

    # Compute gradients
    grad_up = grad_out * gelu_gate

    # Compute gate gradient using GELU derivative
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with phi(x) = exp(-x^2 / 2) / sqrt(2*pi)
    temp = grad_out * up
    t = 0.3989422804014327  # 1/sqrt(2*pi)
    dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
    grad_gate = temp.to(tl.float32) * dgelu_dgate
    grad_gate = grad_gate.to(grad_out.dtype)

    # Store results
    # All loads above have completed for these lanes, so overwriting inputs is safe here
    tl.store(grad_out_ptr + offsets, h, mask=mask)
    tl.store(gate_ptr + offsets, grad_gate, mask=mask)
    tl.store(up_ptr + offsets, grad_up, mask=mask)
|
||||
|
||||
|
||||
def geglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the in-place GEGLU backward kernel.

    Args:
        grad_output: Gradient of loss with respect to output, shape
            `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape
            `[batch, seq_len, hidden_dim]`.

    Returns:
        The same three tensor objects that were passed in, whose contents have been
        overwritten by the kernel:
            - `grad_output` now holds the GEGLU activation output (`h`)
            - `gate` now holds the gradient with respect to gate (`grad_gate`)
            - `up` now holds the gradient with respect to up (`grad_up`)

    Note:
        The inputs are destroyed: callers must not rely on their original values
        after this call.
    """
    total = grad_output.numel()

    def launch_grid(meta):
        # One program per BLOCK_SIZE-wide chunk of the flattened tensors
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _geglu_bwd_kernel[launch_grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
        up_ptr=up,
        n_elements=total,
        BLOCK_SIZE=1024,
    )

    return grad_output, gate, up
|
||||
779
src/axolotl/kernels/lora.py
Normal file
779
src/axolotl/kernels/lora.py
Normal file
@@ -0,0 +1,779 @@
|
||||
"""
|
||||
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.
|
||||
|
||||
See "LoRA: Low-Rank Adaptation of Large Language Models"
|
||||
(https://arxiv.org/abs/2106.09685).
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
from torch import nn
|
||||
|
||||
from .geglu import geglu_backward, geglu_forward
|
||||
from .quantize import dequantize
|
||||
from .swiglu import swiglu_backward, swiglu_forward
|
||||
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
|
||||
|
||||
|
||||
def get_lora_parameters(
    proj: nn.Module,
) -> tuple[
    torch.Tensor,
    QuantState | None,
    torch.Tensor | None,
    torch.Tensor | None,
    float | None,
]:
    """
    Collect the base weight and (if active) the LoRA tensors of a projection.

    Args:
        proj: The projection module to extract parameters from. May be a plain
            layer or a LoRA wrapper exposing `base_layer`, `lora_A`, `lora_B` and
            `scaling`.

    Returns:
        A 5-tuple `(W, quant_state, A, B, s)`: the base weight, its `quant_state`
        attribute (None when absent), the LoRA A and B weights and the scaling
        factor of the first active adapter. `A`, `B` and `s` are None when `proj`
        has no `disable_adapters` attribute, or when adapters are disabled or
        merged.
    """
    # Wrapped layers keep the real weights on `base_layer`; plain layers hold them directly
    base_layer = getattr(proj, "base_layer", proj)
    W = base_layer.weight
    quant_state = getattr(W, "quant_state", None)

    # For DPO or disabled adapters: only the base weight participates
    adapters_inactive = (
        not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged
    )
    if adapters_inactive:
        return W, quant_state, None, None, None

    if hasattr(proj, "active_adapters"):
        adapter_name = proj.active_adapters[0]
    else:
        adapter_name = proj.active_adapter

    return (
        W,
        quant_state,
        proj.lora_A[adapter_name].weight,
        proj.lora_B[adapter_name].weight,
        proj.scaling[adapter_name],
    )
|
||||
|
||||
|
||||
def matmul_lora(
    X: torch.Tensor,
    W: torch.Tensor,
    W_quant: QuantState | None,
    A: torch.Tensor | None,
    B: torch.Tensor | None,
    s: float | None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Efficient fused matmul + LoRA computation.

    Args:
        X: Input tensor [*, in_features]. A 3-D input is flattened to 2-D for the
            matmul (via `view`, so it must be viewable) and the result is reshaped
            back to `[batch, seq_len, out_features]`.
        W: Base weight matrix [out_features, in_features], possibly quantized.
        W_quant: Quantization state for W, or None if W is not quantized.
        A: LoRA A matrix [rank, in_features], or None to skip the LoRA term.
        B: LoRA B matrix [out_features, rank]. Only read when `A` is not None.
        s: LoRA scaling factor. Only read when `A` is not None.
        out: Optional output tensor the base matmul writes into.

    Returns:
        `X @ W.T + s * (X @ A.T) @ B.T`, or just `X @ W.T` when `A` is None.
    """
    dtype = X.dtype
    # `dequantize` returns its input unchanged when there is no quant state
    W = dequantize(W.t(), W_quant)

    reshape = X.dim() == 3
    if reshape:
        batch, seq_len, _ = X.shape
        X = X.view(-1, X.shape[-1])

    out = torch.matmul(X, W, out=out)
    if W_quant is not None:
        # Drop the temporary dequantized weight before the LoRA matmuls allocate
        del W

    if A is not None:
        A, B = A.t(), B.t()
        # Low-rank path: project down to `rank` first so no full-size matrix is formed
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))

    return out.view(batch, seq_len, -1) if reshape else out
|
||||
|
||||
|
||||
class LoRA_MLP(torch.autograd.Function):
    """Optimized LoRA MLP implementation.

    Fuses the gate / up projections, a gated activation and the down projection
    into one autograd node. Only `X`, `gate`, `up` and the LoRA matrices are saved;
    the activation output is recomputed by the activation backward function.
    """

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx,
        X: torch.Tensor,
        gate_weight: torch.Tensor,
        gate_quant: object | None,
        gate_A: torch.Tensor | None,
        gate_B: torch.Tensor | None,
        gate_scale: float,
        up_weight: torch.Tensor,
        up_quant: object | None,
        up_A: torch.Tensor | None,
        up_B: torch.Tensor | None,
        up_scale: float,
        down_weight: torch.Tensor,
        down_quant: object | None,
        down_A: torch.Tensor | None,
        down_B: torch.Tensor | None,
        down_scale: float,
        activation_fn: Callable,
        activation_fn_backward: Callable,
        inplace: bool | None = True,
    ) -> torch.Tensor:
        """
        Forward pass for LoRA MLP.

        Args:
            ctx: Autograd context
            X: Input features
            gate_weight: Gate projection weight
            gate_quant: Gate quantization state
            gate_A: Gate LoRA A matrix
            gate_B: Gate LoRA B matrix
            gate_scale: Gate LoRA scale
            up_weight: Up-projection weight
            up_quant: Up-projection quantization state
            up_A: Up-projection LoRA A matrix
            up_B: Up-projection LoRA B matrix
            up_scale: Up-projection LoRA scale
            down_weight: Down-projection weight
            down_quant: Down-projection quantization state
            down_A: Down-projection LoRA A matrix
            down_B: Down-projection LoRA B matrix
            down_scale: Down-projection LoRA scale
            activation_fn: Forward activation function, called as `fn(gate, up)`
            activation_fn_backward: Backward activation function, called as
                `fn(grad, gate, up)` and returning `(h, grad_gate, grad_up)`
            inplace: Whether backward may write the input gradient into `X`'s
                storage

        Returns:
            Output transformed by multi-layer perceptron and activation function
        """
        # Compute projections
        gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
        up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)

        # Activation
        hidden = activation_fn(gate, up)

        # Down projection
        output = matmul_lora(
            hidden, down_weight, down_quant, down_A, down_B, down_scale
        )

        # Save for backward; `hidden` is deliberately not saved (recomputed later)
        ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
        ctx.scales = (gate_scale, up_scale, down_scale)
        ctx.quants = (gate_quant, up_quant, down_quant)
        ctx.weights = (gate_weight, up_weight, down_weight)
        ctx.activation_fn = activation_fn
        ctx.activation_fn_backward = activation_fn_backward
        ctx.inplace = inplace

        return output

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_output: torch.Tensor,
    ) -> tuple[
        torch.Tensor | None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        None,
    ]:
        """
        Performs backward pass computation for LoRA MLP.

        Args:
            ctx: Context object storing tensors saved during forward pass
            grad_output: Gradient of loss with respect to layer output

        Returns:
            Tuple containing gradients for all inputs from forward pass:
            - Input gradient tensor (or `None`)
            - `None` for weights/quantization states
            - LoRA A/B matrix gradients (or `None`)
            - `None` for scaling factors
            - `None` for activation functions and flags

        Note:
            When `ctx.inplace` is set, the input gradient is written into the
            storage of the saved input `X`. The LoRA gradients that read `X` are
            therefore computed before the input gradient.
        """
        (
            X,
            gate,
            up,
            gate_A,
            gate_B,
            up_A,
            up_B,
            down_A,
            down_B,
        ) = ctx.saved_tensors
        gate_scale, up_scale, down_scale = ctx.scales
        gate_quant, up_quant, down_quant = ctx.quants
        gate_weight, up_weight, down_weight = ctx.weights

        # Transpose all LoRA matrices
        gate_A, gate_B = (
            gate_A.t() if gate_A is not None else None,
            gate_B.t() if gate_B is not None else None,
        )
        up_A, up_B = (
            up_A.t() if up_A is not None else None,
            up_B.t() if up_B is not None else None,
        )
        down_A, down_B = (
            down_A.t() if down_A is not None else None,
            down_B.t() if down_B is not None else None,
        )

        # Reshape inputs
        batch, seq_len, hd = X.shape
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        X = X.view(-1, X.shape[-1])
        gate = gate.view(-1, gate.shape[-1])
        up = up.view(-1, up.shape[-1])
        dtype = X.dtype

        # Down projection: A and B swap roles because we multiply by the transpose
        DW = matmul_lora(
            grad_output,
            down_weight.t(),
            down_quant,
            down_B,
            down_A,
            down_scale,
        )

        # Activation backward
        h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)

        # Initialize and compute LoRA gradients
        d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None

        if down_A is not None:
            d_down_A = h.t() @ (grad_output @ down_B.t())
            d_down_B = (down_A.t() @ h.t()) @ grad_output
            d_down_A *= down_scale
            d_down_B *= down_scale

        if up_A is not None:
            d_up_A = X.t() @ (grad_up @ up_B.t())
            d_up_B = (up_A.t() @ X.t()) @ grad_up
            d_up_A *= up_scale
            d_up_B *= up_scale

        if gate_A is not None:
            d_gate_A = X.t() @ (grad_gate @ gate_B.t())
            d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
            d_gate_A *= gate_scale
            d_gate_B *= gate_scale

        # Compute input gradients. No placeholder tensor is allocated up front: the
        # first matmul below creates (or reuses) the buffer.
        dX = None

        if ctx.needs_input_grad[0]:
            # Up projection gradients
            up_weight = dequantize(up_weight.t(), up_quant)
            if ctx.inplace:
                # Reuse X's storage; X is no longer needed after the LoRA grads above
                dX = torch.matmul(grad_up, up_weight.t(), out=X)
            else:
                dX = torch.matmul(grad_up, up_weight.t())
            del up_weight

            # Note the .to(dtype) only where mixing LoRA with base weights
            if up_A is not None:
                dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())

            # Gate projection gradients
            gate_weight = dequantize(gate_weight.t(), gate_quant)
            dX += grad_gate @ gate_weight.t()
            del gate_weight

            if gate_A is not None:
                dX += (
                    grad_gate
                    @ gate_B.to(dtype).t()
                    @ (gate_scale * gate_A.to(dtype).t())
                )

            # Reshape back
            dX = dX.view(batch, seq_len, hd)

        # Return gradients in correct order matching forward inputs
        return (
            dX,
            None,
            None,
            d_gate_A.t() if d_gate_A is not None else None,
            d_gate_B.t() if d_gate_B is not None else None,
            None,
            None,
            None,
            d_up_A.t() if d_up_A is not None else None,
            d_up_B.t() if d_up_B is not None else None,
            None,
            None,
            None,
            d_down_A.t() if d_down_A is not None else None,
            d_down_B.t() if d_down_B is not None else None,
            None,
            None,
            None,
            None,
        )
|
||||
|
||||
|
||||
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Run the fused LoRA MLP with the SwiGLU activation.

    Intended to be bound as a method on an MLP module exposing `gate_proj`,
    `up_proj` and `down_proj`.

    Args:
        X: Input tensor for the MLP layer
        inplace: Whether to perform operations in-place to save memory

    Returns:
        Output tensor after applying LoRA-adapted MLP with SwiGLU activation
    """
    # Each call yields (weight, quant_state, lora_A, lora_B, scale) for one projection
    gate_params = get_lora_parameters(self.gate_proj)
    up_params = get_lora_parameters(self.up_proj)
    down_params = get_lora_parameters(self.down_proj)

    return LoRA_MLP.apply(
        X,
        *gate_params,
        *up_params,
        *down_params,
        swiglu_forward,
        swiglu_backward,
        inplace,
    )
|
||||
|
||||
|
||||
def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Run the fused LoRA MLP with the GEGLU activation.

    Intended to be bound as a method on an MLP module exposing `gate_proj`,
    `up_proj` and `down_proj`.

    Args:
        X: Input tensor for the MLP layer
        inplace: Whether to perform operations in-place to save memory

    Returns:
        Output tensor after applying LoRA-adapted MLP with GEGLU activation
    """
    # Each call yields (weight, quant_state, lora_A, lora_B, scale) for one projection
    gate_params = get_lora_parameters(self.gate_proj)
    up_params = get_lora_parameters(self.up_proj)
    down_params = get_lora_parameters(self.down_proj)

    return LoRA_MLP.apply(
        X,
        *gate_params,
        *up_params,
        *down_params,
        geglu_forward,
        geglu_backward,
        inplace,
    )
|
||||
|
||||
|
||||
class LoRA_QKV(torch.autograd.Function):
    """
    Optimized LoRA QKV implementation with quantization support.

    Implements efficient computation of query, key, value projections with LoRA,
    supporting quantization and memory optimization.
    """

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        q_weight: torch.Tensor,
        q_quant: QuantState | None,
        q_A: torch.Tensor | None,
        q_B: torch.Tensor | None,
        q_scale: float,
        k_weight: torch.Tensor,
        k_quant: QuantState | None,
        k_A: torch.Tensor | None,
        k_B: torch.Tensor | None,
        k_scale: float,
        v_weight: torch.Tensor,
        v_quant: QuantState | None,
        v_A: torch.Tensor | None,
        v_B: torch.Tensor | None,
        v_scale: float,
        inplace: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass computing Q, K, V projections with LoRA.

        Args:
            ctx: Autograd context
            X: Input tensor
            q_weight: Query projection weight
            q_quant: Query quantization state
            q_A: Query LoRA A matrix
            q_B: Query LoRA B matrix
            q_scale: Query LoRA scale
            k_weight: Key projection weight
            k_quant: Key quantization state
            k_A: Key LoRA A matrix
            k_B: Key LoRA B matrix
            k_scale: Key LoRA scale
            v_weight: Value projection weight
            v_quant: Value quantization state
            v_A: Value LoRA A matrix
            v_B: Value LoRA B matrix
            v_scale: Value LoRA scale
            inplace: Whether backward may write the input gradient into `X`'s
                storage

        Returns:
            Tuple of (Query, Key, Value) projection tensors
        """
        Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
        K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
        V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)

        ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
        ctx.scales = (q_scale, k_scale, v_scale)
        ctx.quants = (q_quant, k_quant, v_quant)
        ctx.weights = (q_weight, k_weight, v_weight)
        ctx.inplace = inplace

        return Q, K, V

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        q_grad: torch.Tensor,
        k_grad: torch.Tensor,
        v_grad: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
    ]:
        """
        Backward pass computing gradients for LoRA QKV.

        For each projection `Y = X @ W.T + s * (X @ A.T) @ B.T` the gradients are
        `dA = s * B.T @ dY.T @ X`, `dB = s * dY.T @ X @ A.T` and
        `dX = dY @ W + s * dY @ B @ A`.

        Args:
            ctx: Autograd context
            q_grad: Gradient for query projection
            k_grad: Gradient for key projection
            v_grad: Gradient for value projection

        Returns:
            Tuple containing gradients for all forward inputs

        Note:
            When `ctx.inplace` is set, the input gradient is written into the
            storage of the saved input `X`. The LoRA gradients that read `X` are
            therefore computed before the input gradient.
        """
        X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
        q_weight, k_weight, v_weight = ctx.weights
        q_quant, k_quant, v_quant = ctx.quants
        q_scale, k_scale, v_scale = ctx.scales
        dtype = X.dtype

        # Reshape gradients; `reshape` also accepts non-contiguous incoming grads
        batch, seq_len = X.shape[:2]
        q_grad = q_grad.reshape(-1, q_grad.shape[-1])
        k_grad = k_grad.reshape(-1, k_grad.shape[-1])
        v_grad = v_grad.reshape(-1, v_grad.shape[-1])
        X = X.view(-1, X.shape[-1])

        # Pre-transpose X once
        X_t = X.t()

        # Initialize LoRA gradients as None
        d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None

        # Compute q path LoRA gradients if adapters exist.
        # The scale is folded into A only, so d_B and grad_X pick it up through
        # A_*_scaled while d_A (which does not involve A) is scaled explicitly.
        if A_q is not None and B_q is not None:
            A_q_scaled = (q_scale * A_q).to(dtype)
            B_q_scaled = B_q.to(dtype)
            d_A_q = q_scale * torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
            d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)

        # Compute k path LoRA gradients if adapters exist
        if A_k is not None and B_k is not None:
            A_k_scaled = (k_scale * A_k).to(dtype)
            B_k_scaled = B_k.to(dtype)
            d_A_k = k_scale * torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
            d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)

        # Compute v path LoRA gradients if adapters exist
        if A_v is not None and B_v is not None:
            A_v_scaled = (v_scale * A_v).to(dtype)
            B_v_scaled = B_v.to(dtype)
            d_A_v = v_scale * torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
            d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)

        # Compute input gradient, reusing X memory if possible
        out_buffer = X if ctx.inplace else None

        # Q path
        q_weight_t = dequantize(q_weight, q_quant)
        grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
        del q_weight
        del q_weight_t
        if A_q is not None and B_q is not None:
            grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))

        # K path
        k_weight_t = dequantize(k_weight, k_quant)
        grad_X.addmm_(k_grad, k_weight_t)
        del k_weight
        del k_weight_t
        if A_k is not None and B_k is not None:
            grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))

        # V path
        v_weight_t = dequantize(v_weight, v_quant)
        grad_X.addmm_(v_grad, v_weight_t)
        del v_weight
        del v_weight_t
        if A_v is not None and B_v is not None:
            grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))

        # Transpose gradients back to the layout of the LoRA weight tensors
        if d_A_q is not None:
            d_A_q = d_A_q.t()
        if d_B_q is not None:
            d_B_q = d_B_q.t()
        if d_A_k is not None:
            d_A_k = d_A_k.t()
        if d_B_k is not None:
            d_B_k = d_B_k.t()
        if d_A_v is not None:
            d_A_v = d_A_v.t()
        if d_B_v is not None:
            d_B_v = d_B_v.t()

        return (
            grad_X.view(batch, seq_len, -1),
            None,
            None,
            d_A_q,
            d_B_q,
            None,
            None,
            None,
            d_A_k,
            d_B_k,
            None,
            None,
            None,
            d_A_v,
            d_B_v,
            None,
            None,
        )
|
||||
|
||||
|
||||
def apply_lora_qkv(
    self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the Query, Key and Value projections through the fused LoRA kernel.

    Intended to be bound as a method on an attention module exposing `q_proj`,
    `k_proj` and `v_proj`.

    Args:
        X: Input tensor
        inplace: Whether to perform operations in-place

    Returns:
        Tuple of (Query, Key, Value) projection tensors
    """
    # Each call yields (weight, quant_state, lora_A, lora_B, scale) for one projection
    q_params = get_lora_parameters(self.q_proj)
    k_params = get_lora_parameters(self.k_proj)
    v_params = get_lora_parameters(self.v_proj)

    Q, K, V = LoRA_QKV.apply(X, *q_params, *k_params, *v_params, inplace)

    return Q, K, V
|
||||
|
||||
|
||||
class LoRA_O(torch.autograd.Function):
    """Optimized LoRA implementation for output projection."""

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        W: torch.Tensor,
        W_quant: QuantState | None,
        A: torch.Tensor | None,
        B: torch.Tensor | None,
        S: float,
    ) -> torch.Tensor:
        """
        Forward pass for output projection with LoRA.

        Args:
            ctx: Autograd context
            X: Input tensor
            W: Output projection weight
            W_quant: Weight quantization state
            A: LoRA A matrix (None when no adapter is active)
            B: LoRA B matrix (None when no adapter is active)
            S: LoRA scaling factor

        Returns:
            Output projection tensor
        """
        XW = matmul_lora(X, W, W_quant, A, B, S)
        # Non-tensor state lives on ctx directly; tensors go through save_for_backward
        ctx.custom_saved_tensors = (
            W,
            W_quant,
            S,
        )
        ctx.save_for_backward(A, B, X)

        return XW

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        dY: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
    ]:
        """
        Backward pass computing gradients for LoRA output projection.

        Args:
            ctx: Autograd context
            dY: Gradient of loss with respect to output

        Returns:
            Tuple containing gradients for all forward inputs, in the order
            `(X, W, W_quant, A, B, S)`. The `A` / `B` entries are None when the
            forward pass ran without LoRA matrices.
        """
        W, W_quant, S = ctx.custom_saved_tensors
        A, B, X = ctx.saved_tensors

        batch, seq_len, hd = X.shape
        dY = dY.reshape(-1, dY.shape[-1])
        X = X.reshape(-1, X.shape[-1])
        dtype = X.dtype

        # Adapters may be disabled or merged, in which case forward received A = B = None
        has_lora = A is not None and B is not None

        # Weight projection
        d_A = d_B = None
        if has_lora:
            dY_X = X.t() @ dY
            d_A = (S * dY_X @ B).t()
            d_B = (S * A @ dY_X).t()

        # Get derivative for dX
        W = dequantize(W.t(), W_quant)
        dX = dY @ W.t()
        del W
        if has_lora:
            dX += dY @ B.to(dtype) @ (S * A.to(dtype))

        # W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), None, None, d_A, d_B, None
|
||||
|
||||
|
||||
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
    """
    Run the output projection through the fused LoRA kernel.

    Intended to be bound as a method on an attention module exposing `o_proj`.

    Args:
        X: Input tensor

    Returns:
        Transformed output tensor
    """
    # (weight, quant_state, lora_A, lora_B, scale) of the output projection
    o_params = get_lora_parameters(self.o_proj)

    return LoRA_O.apply(X, *o_params)
|
||||
149
src/axolotl/kernels/quantize.py
Normal file
149
src/axolotl/kernels/quantize.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Dequantization utilities for `bitsandbytes` integration."""
|
||||
# pylint: disable=invalid-name,global-statement
|
||||
|
||||
import ctypes
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState, get_ptr
|
||||
from packaging.version import Version
|
||||
|
||||
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
||||
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
|
||||
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
|
||||
|
||||
CUDA_STREAM: torch.cuda.Stream | None = None
|
||||
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
|
||||
|
||||
|
||||
def dequantize(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Fast NF4 dequantization using `bitsandbytes` CUDA kernels.

    Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
    optimized CUDA implementations. Supports both legacy list and new `QuantState`
    formats.

    Args:
        W: Quantized weight tensor to dequantize
        quant_state: Quantization state containing metadata needed for
            dequantization. Can be either a `QuantState` object or legacy list format.
            If None, returns `W` unchanged. The state's `offset` and nested `state2`
            are read unconditionally, so both must be present.
        out: Optional output tensor for storing dequantized results. Must match
            expected shape and dtype if provided.

    Returns:
        Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
        input `W` was transposed.

    Raises:
        AssertionError: If provided output tensor doesn't match expected shape / dtype.

    Note:
        Uses CUDA streams when available in newer `bitsandbytes` versions (>0.43.3).
        The stream is looked up on every call for the device of `W` (and mirrored into
        the module-level `CUDA_STREAM`), so calls on a different device or under a
        different current stream never reuse a stale stream.
    """
    global CUDA_STREAM

    if quant_state is None:
        return W

    # Everything below is placed on the device of the quantized weights.
    target_device = W.device

    # Extract quantization state
    if not isinstance(quant_state, list):
        # New style quant_state class
        absmax = quant_state.absmax.to(target_device)
        shape = quant_state.shape
        dtype = quant_state.dtype
        blocksize = quant_state.blocksize
        offset = quant_state.offset.to(target_device)
        state2 = quant_state.state2
        absmax2 = state2.absmax.to(target_device)
        code2 = state2.code.to(target_device)
        blocksize2 = state2.blocksize
    else:
        # Legacy list format
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        absmax = absmax.to(target_device)
        offset, state2 = compressed_stats
        offset = offset.to(target_device)
        absmax2, code2, blocksize2, _, _, _, _ = state2
        absmax2 = absmax2.to(target_device)
        code2 = code2.to(target_device)

    # Setup output tensor on the same device as input
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=target_device)
    else:
        assert out.shape == shape and out.dtype == dtype
        out = out.to(target_device)

    # Scratch buffer that receives the dequantized (fp32) block statistics
    n_elements_absmax: int = absmax.numel()
    out_absmax: torch.Tensor = torch.empty(
        n_elements_absmax, dtype=torch.float32, device=target_device
    )
    ptr_out_absmax: int = get_ptr(out_absmax)

    # Newer `bitsandbytes` kernels take the stream as a trailing argument. Fetch the
    # current stream each call instead of caching the first one ever seen.
    stream_args: tuple = ()
    if HAS_CUDA_STREAM:
        CUDA_STREAM = torch.cuda.current_stream(target_device)
        stream_args = (CUDA_STREAM,)

    # Dequantize statistics on the target device
    cdequantize_blockwise_fp32(
        get_ptr(code2),
        get_ptr(absmax),
        get_ptr(absmax2),
        ptr_out_absmax,
        ctypes.c_int(blocksize2),
        ctypes.c_int(n_elements_absmax),
        *stream_args,
    )
    out_absmax += offset

    # Choose appropriate dequantization function (anything but fp16 uses bf16)
    fx = (
        cdequantize_blockwise_fp16_nf4
        if dtype == torch.float16
        else cdequantize_blockwise_bf16_nf4
    )

    # Dequantize weights
    fx(
        get_ptr(None),
        get_ptr(W),
        ptr_out_absmax,
        get_ptr(out),
        ctypes.c_int(blocksize),
        ctypes.c_int(out.numel()),
        *stream_args,
    )

    # Handle transposed data: a leading dimension of 1 is treated as the marker for
    # a transposed packed weight.
    is_transposed: bool = W.shape[0] == 1
    return out.t() if is_transposed else out
|
||||
163
src/axolotl/kernels/swiglu.py
Normal file
163
src/axolotl/kernels/swiglu.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Module for definition of SwiGLU Triton kernels.
|
||||
|
||||
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _swiglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    """
    Forward SwiGLU kernel: writes `silu(gate) * up` to `out_ptr`.

    The SiLU activation is evaluated in fp32 and cast to the dtype of `up` before the
    final multiplication.

    Args:
        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
        out_ptr: Pointer to output tensor `[*, hidden_dim]`.
        n_elements: Total number of elements in the input tensors.
        block_size: Number of elements handled by one program instance.
    """
    pid = tl.program_id(0)
    idx = pid * block_size + tl.arange(0, block_size)
    # Guard the tail block, which may extend past the end of the data.
    in_bounds = idx < n_elements

    # Gate is promoted to fp32; up stays in its stored dtype.
    gate_f32 = tl.load(gate_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    up_vals = tl.load(up_ptr + idx, mask=in_bounds, other=0)

    # SiLU in fp32, then back to the dtype of `up` before multiplying.
    silu = (gate_f32 * tl.sigmoid(gate_f32)).to(up_vals.dtype)

    tl.store(out_ptr + idx, silu * up_vals, mask=in_bounds)
|
||||
|
||||
|
||||
@triton.jit
def _swiglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    """
    Backward SwiGLU kernel that overwrites its three input buffers with results.

    Args:
        grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
        n_elements: Total number of elements in the input tensors.
        block_size: Number of elements handled by one program instance.

    Note:
        On return the buffers hold:
            - `grad_out_ptr`: recomputed forward output (`h`)
            - `gate_ptr`: gradient w.r.t. gate (`grad_gate`)
            - `up_ptr`: gradient w.r.t. up (`grad_up`)
    """
    pid = tl.program_id(0)
    idx = pid * block_size + tl.arange(0, block_size)
    in_bounds = idx < n_elements

    # All loads happen before any store, since the stores reuse the same buffers.
    dy = tl.load(grad_out_ptr + idx, mask=in_bounds, other=0)
    gate_f32 = tl.load(gate_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    up_vals = tl.load(up_ptr + idx, mask=in_bounds, other=0)

    # SiLU(gate) in fp32, cast to the gradient dtype, then the forward output.
    sig = tl.sigmoid(gate_f32)
    silu = (sig * gate_f32).to(dy.dtype)
    fwd_out = silu * up_vals

    # d/d(up) of silu(gate) * up is silu(gate).
    d_up = dy * silu

    # d/d(gate): sigmoid(g) * (1 + g * (1 - sigmoid(g))), scaled by dy * up.
    dy_up = dy * up_vals
    d_gate = dy_up.to(tl.float32) * sig * (1.0 + gate_f32 * (1.0 - sig))
    d_gate = d_gate.to(dy.dtype)

    tl.store(grad_out_ptr + idx, fwd_out, mask=in_bounds)
    tl.store(gate_ptr + idx, d_gate, mask=in_bounds)  # grad wrt gate
    tl.store(up_ptr + idx, d_up, mask=in_bounds)  # grad wrt up
|
||||
|
||||
|
||||
# pylint: disable=unnecessary-lambda-assignment
|
||||
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """
    SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
    `x` is the gate tensor.

    Args:
        gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Output tensor of shape `[batch, seq_len, hidden_dim]`, allocated with the
        dtype and device of `gate`.
    """
    batch, seq_len, hidden_dim = gate.shape
    n_elements = gate.numel()

    # Allocate next to the inputs rather than on whatever "cuda" resolves to, so the
    # kernel never writes to a buffer on a different GPU than the one it reads from.
    out = torch.empty(
        (batch, seq_len, hidden_dim), dtype=gate.dtype, device=gate.device
    )

    # One program instance per `block_size` elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),)  # noqa: E731
    _swiglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        block_size=1024,
    )

    return out
|
||||
|
||||
|
||||
# pylint: disable=unnecessary-lambda-assignment
|
||||
def swiglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    SwiGLU backward pass; the three argument tensors are overwritten by the kernel.

    Args:
        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.

    Returns:
        The same three tensor objects, which after the kernel hold:
            - `grad_output`: forward pass output (`h`)
            - `gate`: gradient with respect to gate
            - `up`: gradient with respect to up-projection
    """
    total = grad_output.numel()

    def launch_grid(meta):
        # One program instance per `block_size` elements.
        return (triton.cdiv(total, meta["block_size"]),)

    _swiglu_bwd_kernel[launch_grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
        up_ptr=up,
        n_elements=total,
        block_size=1024,
    )

    # The kernel wrote its results into the input buffers, so hand those back.
    return grad_output, gate, up
|
||||
11
src/axolotl/kernels/utils.py
Normal file
11
src/axolotl/kernels/utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Utilities for `axolotl.kernels` submodules."""
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
# Choose the AMP custom-function decorators by installed torch version: from 2.4.0
# on, the device-generic `torch.amp` variants bound to CUDA; before that, the
# `torch.cuda.amp` ones.
if Version(torch.__version__) >= Version("2.4.0"):
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
else:
    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
333
src/axolotl/monkeypatch/lora_kernels.py
Normal file
333
src/axolotl/monkeypatch/lora_kernels.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_QKV_CODE = """
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
||||
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
attn_output = self.o_proj(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_O_CODE = """
|
||||
attn_output = self.apply_o(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
|
||||
APPLY_FN_MAPPING = {
|
||||
"silu": apply_lora_mlp_swiglu,
|
||||
"gelu": apply_lora_mlp_geglu,
|
||||
}
|
||||
|
||||
|
||||
def original_apply_qkv(
    self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Fallback QKV projection that simply calls the module's own projection layers.

    Args:
        self: The attention module instance.
        hidden_states: Input tensor passed unchanged to each projection.

    Returns:
        A tuple `(query_states, key_states, value_states)` with the outputs of
        `q_proj`, `k_proj` and `v_proj`, in that order.
    """
    # Tuple elements are evaluated left to right: q, then k, then v.
    return (
        self.q_proj(hidden_states),
        self.k_proj(hidden_states),
        self.v_proj(hidden_states),
    )
|
||||
|
||||
|
||||
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Fallback output projection that simply calls the module's own `o_proj`.

    Args:
        self: The attention module instance.
        hidden_states: Input tensor passed unchanged to `o_proj`.

    Returns:
        The output projection result.
    """
    return self.o_proj(hidden_states)
|
||||
|
||||
|
||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
    """
    Get the appropriate attention class by inspecting the model config.

    Uses dynamic import to support any model architecture that follows
    the standard transformers naming convention, i.e. a module
    `transformers.models.<model_type>.modeling_<model_type>` exposing a class named
    `<ModelType>Attention`, where `<ModelType>` is `model_type` with each
    underscore-separated part capitalized and joined (`qwen2_moe` -> `Qwen2Moe`).

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.

    Returns:
        The appropriate attention class for the model.

    Raises:
        ValueError: If `base_model` is not specified, or if the model module or the
            attention class cannot be imported (the underlying `ImportError` /
            `AttributeError` is attached as `__cause__`).
    """
    if "base_model" not in cfg:
        raise ValueError("base_model must be specified in config")

    # Get model config without loading the model
    model_config = AutoConfig.from_pretrained(cfg["base_model"])
    model_type = model_config.model_type

    # Explicit path for model_type = "qwen2"
    # NOTE(review): the generic lookup below resolves to the same class name; kept
    # as-is to leave existing behavior for this model type untouched.
    if model_type == "qwen2":
        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention

        return Qwen2Attention

    # Build the CamelCase class prefix part by part. A plain `capitalize()` on the
    # whole string would keep the underscores and lowercase the later parts.
    class_prefix = "".join(part.capitalize() for part in model_type.split("_"))
    attention_cls_name = f"{class_prefix}Attention"
    module_path = f"transformers.models.{model_type}.modeling_{model_type}"

    try:
        # Dynamically import the module and attention class
        module = importlib.import_module(module_path)
        return getattr(module, attention_cls_name)
    except (ImportError, AttributeError) as e:
        raise ValueError(
            f"Could not import attention class for model_type: {model_type}. "
            f"Error: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(cfg: DictDefault):
    """
    Given an `axolotl` config, this method patches the inferred attention class forward
    pass with optimized LoRA implementations.

    The source of the attention class' `forward` is rewritten so the QKV and output
    projections go through `self.apply_qkv` / `self.apply_o`, compiled with `exec`
    into this module's globals, and installed as the class' new `forward`. The
    unmodified source text of the original `forward` is stored on the class as
    `_original_forward`; that attribute doubles as the "already patched" marker and
    is only set once patching has succeeded.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.

    Raises:
        AssertionError: If the required code blocks are not found in the attention
            implementation. The class is left unmarked in that case.
    """
    attention_cls = get_attention_cls_from_config(cfg)

    # Check if already patched
    if hasattr(attention_cls, "_original_forward"):
        LOG.info(f"{attention_cls.__name__} already patched")
        return

    original_source = inspect.getsource(attention_cls.forward)
    self_attn_forward, _ = detab_code(original_source)

    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
    assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"

    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
    # Rename only the first occurrence, i.e. the function header itself.
    self_attn_forward = self_attn_forward.replace(
        "def forward(",
        "def axolotl_attn_forward(",
        1,
    )

    # Load necessary imports: every name of the defining module whose text occurs
    # anywhere in the rewritten source is imported into this module's globals so the
    # exec'd function can resolve it.
    module_name = attention_cls.__module__
    module = importlib.import_module(module_name)
    items_to_import = [item for item in dir(module) if item in self_attn_forward]

    exec(  # pylint: disable=exec-used  # nosec B102
        f"from {module_name} import ({', '.join(items_to_import)})",
        globals(),
    )
    exec(self_attn_forward, globals())  # pylint: disable=exec-used  # nosec B102

    LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
    attention_cls.forward = (
        axolotl_attn_forward  # pylint: disable=undefined-variable  # noqa: F821
    )
    # Mark as patched last, so a failure above does not make a later retry report
    # "already patched" for a class whose forward was never replaced.
    attention_cls._original_forward = original_source
|
||||
|
||||
|
||||
def _can_patch_projections(projections) -> bool:
    """
    Return True when every projection is eligible for the fused LoRA kernels.

    A projection qualifies if it exposes `lora_A`, its base layer (or the module
    itself when it has no `base_layer`) has no bias, and it carries no
    `lora_magnitude_vector` entries.
    """
    return all(
        hasattr(proj, "lora_A")
        and getattr(proj, "base_layer", proj).bias is None
        and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
        for proj in projections
    )


def apply_lora_kernel_patches(
    model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
    """
    Applies optimized Triton kernel patches to a PEFT model.

    Patches a PEFT model with optimized implementations for MLP and attention
    computations. The optimizations include custom Triton kernels for activation
    functions and specialized autograd functions for LoRA computations.

    Every layer's attention module first receives the unoptimized `apply_qkv` /
    `apply_o` fallbacks; these are replaced by the optimized versions only when the
    corresponding `cfg` flag is set and the layer's projections are eligible.

    Args:
        model: A PEFT model to be patched with optimized kernels.
        cfg: Dictionary mapping `axolotl` config keys to values. Reads
            `lora_mlp_kernel`, `lora_qkv_kernel` and `lora_o_kernel`.

    Returns:
        PeftModelForCausalLM: The patched model with optimized kernels (the same
        object that was passed in).

    Raises:
        TypeError: If the provided model is not a `PeftModelForCausalLM`.
        NotImplementedError: If the model's `hidden_act` is not supported.
        AssertionError: If multiple adapters are active (currently unsupported).

    Note:
        The optimizations require LoRA adapters with no dropout and no bias terms. The
        function will skip patching if these conditions aren't met.
    """
    if not isinstance(model, PeftModelForCausalLM):
        raise TypeError("Model must be a PeftModelForCausalLM")

    # Get active LoRA adapter config
    if hasattr(model, "active_adapters"):
        assert (
            len(model.active_adapters) == 1
        ), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
        active_adapter = model.active_adapters[0]
    else:
        active_adapter = model.active_adapter
    lora_config = model.model.peft_config[active_adapter]

    # Only patch if conditions are met
    can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"

    if not can_patch:
        LOG.warning("Cannot patch layers - requires no dropout and no bias")
        LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
        return model

    # Choose activation based on model type. Validated before the log level is
    # touched so that an unsupported activation cannot leave the level changed.
    activation = model.config.hidden_act
    if activation not in SUPPORTED_ACTIVATIONS:
        raise NotImplementedError(f"Activation {activation} is not supported")

    # Temporarily raise verbosity; restored in `finally` even if patching fails.
    original_level = LOG.getEffectiveLevel()
    LOG.setLevel(logging.INFO)

    try:
        # Patch each layer
        for layer in model.model.model.layers:
            # Add QKV, O fallback implementations to start
            # These will be overwritten later (if some conditions apply)
            layer.self_attn.apply_qkv = types.MethodType(
                original_apply_qkv, layer.self_attn
            )
            layer.self_attn.apply_o = types.MethodType(
                original_apply_o, layer.self_attn
            )

            if cfg.lora_mlp_kernel:
                # MLP patching
                mlp_projections = (
                    layer.mlp.gate_proj,
                    layer.mlp.up_proj,
                    layer.mlp.down_proj,
                )

                if _can_patch_projections(mlp_projections):
                    apply_fn = APPLY_FN_MAPPING[activation]
                    layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
                else:
                    LOG.warning_once(
                        "Cannot patch some MLP layers - requires LoRA adapters with no bias"
                    )

            if cfg.lora_qkv_kernel:
                # Query, key, value patching
                qkv_projections = [
                    getattr(layer.self_attn, linear_proj)
                    for linear_proj in ["q_proj", "k_proj", "v_proj"]
                ]

                if _can_patch_projections(qkv_projections):
                    # Add optimized implementation
                    layer.self_attn.apply_qkv = types.MethodType(
                        apply_lora_qkv, layer.self_attn
                    )
                else:
                    LOG.warning_once(
                        "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
                    )

            if cfg.lora_o_kernel:
                # Output patching
                o_projections = [
                    getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
                ]

                if _can_patch_projections(o_projections):
                    layer.self_attn.apply_o = types.MethodType(
                        apply_lora_o, layer.self_attn
                    )
                else:
                    LOG.warning_once(
                        "Cannot patch some attention output projection - requires LoRA adapters with no bias"
                    )
    finally:
        LOG.setLevel(original_level)

    return model
|
||||
@@ -175,6 +175,7 @@ def train(
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
@@ -185,6 +186,7 @@ def train(
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
"""
|
||||
Module for pydantic models for configuration
|
||||
"""
|
||||
|
||||
"""Module with Pydantic models for configuration."""
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
@@ -810,6 +807,10 @@ class AxolotlInputConfig(
|
||||
unsloth_rms_norm: Optional[bool] = None
|
||||
unsloth_rope: Optional[bool] = None
|
||||
|
||||
lora_mlp_kernel: Optional[bool] = None
|
||||
lora_qkv_kernel: Optional[bool] = None
|
||||
lora_o_kernel: Optional[bool] = None
|
||||
|
||||
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
||||
fsdp: Optional[List[str]] = None
|
||||
fsdp_config: Optional[Dict[str, Any]] = None
|
||||
@@ -1534,12 +1535,42 @@ class AxolotlInputConfig(
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
|
||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||
raise ValueError(
|
||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
    """
    Reject the LoRA kernel flags when combined with an 8-bit LoRA adapter.

    Args:
        data: Raw config mapping, inspected before model validation.

    Returns:
        `data`, unchanged.

    Raises:
        ValueError: If any of `lora_mlp_kernel`, `lora_qkv_kernel` or
            `lora_o_kernel` is set while `adapter` is `"lora"` and `load_in_8bit`
            is truthy.
    """
    if (
        data.get("lora_mlp_kernel")
        or data.get("lora_qkv_kernel")
        or data.get("lora_o_kernel")
    ):
        if data.get("adapter") == "lora" and data.get("load_in_8bit"):
            # Name all three flags; the message previously repeated
            # `lora_mlp_kernel` three times.
            raise ValueError(
                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
            )
    return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
    """
    Reject configs that enable both the axolotl and the unsloth LoRA kernels.

    Args:
        data: Raw config mapping, inspected before model validation.

    Returns:
        `data`, unchanged.

    Raises:
        ValueError: If at least one `lora_*_kernel` flag and at least one
            `unsloth_lora_*` flag are truthy at the same time.
    """
    axolotl_flags = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
    unsloth_flags = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]

    uses_axolotl_kernels = any(data.get(flag) for flag in axolotl_flags)
    uses_unsloth_kernels = any(data.get(flag) for flag in unsloth_flags)

    if uses_axolotl_kernels and uses_unsloth_kernels:
        raise ValueError(
            "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
        )
    return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
@@ -1672,6 +1703,29 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
    """
    Reject the LoRA kernel flags on multi-GPU runs that use FSDP or DeepSpeed.

    Args:
        data: Raw config mapping, inspected before model validation.

    Returns:
        `data`, unchanged.

    Raises:
        ValueError: If any `lora_*_kernel` flag is truthy, `capabilities["n_gpu"]`
            is greater than 1, and `fsdp` or `deepspeed` is set (not None). FSDP is
            checked first.
    """
    kernel_flags = ("lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel")
    if not any(data.get(flag) for flag in kernel_flags):
        return data

    capabilities = data.get("capabilities")
    uses_fsdp = data.get("fsdp") is not None
    uses_deepspeed = data.get("deepspeed") is not None

    # Only multi-GPU setups are restricted; missing capabilities mean no check.
    if not (capabilities and capabilities.get("n_gpu", 0) > 1):
        return data

    if uses_fsdp:
        raise ValueError(
            "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
        )
    if uses_deepspeed:
        raise ValueError(
            "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
        )
    return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_adopt_torch_version(cls, data):
|
||||
|
||||
@@ -414,6 +414,7 @@ class ModelLoader:
|
||||
has_remote_code = "AutoModelForCausalLM" in auto_map_config
|
||||
else:
|
||||
has_remote_code = False
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
@@ -425,10 +426,6 @@ class ModelLoader:
|
||||
|
||||
if self.cfg.is_llama_derived_model:
|
||||
self.patch_loss_llama()
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
elif self.cfg.is_llama_derived_model:
|
||||
self.patch_llama_derived_model()
|
||||
|
||||
@@ -442,6 +439,11 @@ class ModelLoader:
|
||||
|
||||
patch_mistral_cross_entropy()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -472,9 +474,7 @@ class ModelLoader:
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
def patch_loss_llama(self) -> None:
|
||||
"""
|
||||
Patch loss functions
|
||||
"""
|
||||
"""Patch loss functions and other optimizations"""
|
||||
if self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_fa_llama_cross_entropy,
|
||||
@@ -494,15 +494,14 @@ class ModelLoader:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def patch_llama_derived_model(self) -> None:
|
||||
"""
|
||||
Modify all llama derived models in one block
|
||||
"""
|
||||
"""Modify all llama derived models in one block"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
@@ -1013,7 +1012,8 @@ class ModelLoader:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(dist_dtype)
|
||||
|
||||
def apply_lora_patch(self) -> None:
|
||||
# TODO: Deprecate this.
|
||||
def apply_unsloth_lora_patch(self) -> None:
|
||||
if self.cfg.unsloth_lora_mlp:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
||||
|
||||
@@ -1027,6 +1027,16 @@ class ModelLoader:
|
||||
|
||||
integrate_rope_embeddings()
|
||||
|
||||
def apply_lora_patch(self) -> None:
    """
    Apply the axolotl LoRA kernel patches to `self.model` when any of the
    `lora_mlp_kernel`, `lora_qkv_kernel` or `lora_o_kernel` config flags is set.
    """
    cfg = self.cfg
    if not (cfg.lora_mlp_kernel or cfg.lora_qkv_kernel or cfg.lora_o_kernel):
        return

    # Imported here so the kernel module is only loaded when a flag requests it.
    from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches

    apply_lora_kernel_patches(self.model, self.cfg)
|
||||
|
||||
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
self.apply_patches()
|
||||
self.set_auto_model_loader()
|
||||
@@ -1171,6 +1181,7 @@ class ModelLoader:
|
||||
if self.cfg.adapter is not None:
|
||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||
|
||||
self.apply_unsloth_lora_patch()
|
||||
self.apply_lora_patch()
|
||||
|
||||
for _ in range(3):
|
||||
|
||||
Reference in New Issue
Block a user