feat: handle lora kernels compat with torchao
This commit is contained in:
@@ -15,7 +15,7 @@ from torch import nn
|
|||||||
from torch.distributed.tensor import DTensor
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
from .geglu import geglu_backward, geglu_forward
|
from .geglu import geglu_backward, geglu_forward
|
||||||
from .quantize import dequantize
|
from .quantize import dequantize_weight
|
||||||
from .swiglu import swiglu_backward, swiglu_forward
|
from .swiglu import swiglu_backward, swiglu_forward
|
||||||
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
|
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
|
||||||
|
|
||||||
@@ -46,6 +46,12 @@ def get_lora_parameters(
|
|||||||
W = base_layer.weight
|
W = base_layer.weight
|
||||||
b = base_layer.bias
|
b = base_layer.bias
|
||||||
|
|
||||||
|
# Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy
|
||||||
|
# attribute access to the underlying tensor subclass, so torchao methods like
|
||||||
|
# .dequantize() or .get_original_weight() would not be visible.
|
||||||
|
if isinstance(W, DTensor):
|
||||||
|
W = W.full_tensor()
|
||||||
|
|
||||||
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
||||||
quant_state = getattr(W, "quant_state", None)
|
quant_state = getattr(W, "quant_state", None)
|
||||||
return W, b, quant_state, None, None, None
|
return W, b, quant_state, None, None, None
|
||||||
@@ -86,6 +92,7 @@ def matmul_lora(
|
|||||||
B: torch.Tensor | None,
|
B: torch.Tensor | None,
|
||||||
s: float | None,
|
s: float | None,
|
||||||
out: torch.Tensor | None = None,
|
out: torch.Tensor | None = None,
|
||||||
|
transpose: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Efficient fused matmul + LoRA computation.
|
Efficient fused matmul + LoRA computation.
|
||||||
@@ -98,12 +105,15 @@ def matmul_lora(
|
|||||||
B: LoRA B matrix [out_features, rank]
|
B: LoRA B matrix [out_features, rank]
|
||||||
s: LoRA scaling factor
|
s: LoRA scaling factor
|
||||||
out: Optional output tensor for inplace operations
|
out: Optional output tensor for inplace operations
|
||||||
|
transpose: If True (default), transpose W before matmul (forward path).
|
||||||
|
Set to False for backward paths where W is already in the correct layout.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Result of X @ W + X @ A @ B
|
Result of X @ W + X @ A @ B
|
||||||
"""
|
"""
|
||||||
dtype = X.dtype
|
dtype = X.dtype
|
||||||
W = dequantize(W.t(), W_quant)
|
is_quantized = W_quant is not None or type(W) is not torch.Tensor
|
||||||
|
W = dequantize_weight(W, W_quant, transpose=transpose)
|
||||||
|
|
||||||
reshape = False
|
reshape = False
|
||||||
if X.dim() == 3:
|
if X.dim() == 3:
|
||||||
@@ -112,7 +122,7 @@ def matmul_lora(
|
|||||||
reshape = True
|
reshape = True
|
||||||
|
|
||||||
out = torch.matmul(X, W, out=out)
|
out = torch.matmul(X, W, out=out)
|
||||||
if W_quant is not None:
|
if is_quantized:
|
||||||
del W
|
del W
|
||||||
|
|
||||||
if A is not None:
|
if A is not None:
|
||||||
@@ -292,15 +302,16 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
up = up.view(-1, up.shape[-1])
|
up = up.view(-1, up.shape[-1])
|
||||||
dtype = X.dtype
|
dtype = X.dtype
|
||||||
|
|
||||||
# Down projection
|
# Down projection (backward: no transpose needed, W is already [out, in])
|
||||||
grad_down = matmul_lora(
|
grad_down = matmul_lora(
|
||||||
grad_output,
|
grad_output,
|
||||||
down_weight.t(),
|
down_weight,
|
||||||
None,
|
None,
|
||||||
down_quant,
|
down_quant,
|
||||||
down_B,
|
down_B,
|
||||||
down_A,
|
down_A,
|
||||||
down_scale,
|
down_scale,
|
||||||
|
transpose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Activation backward
|
# Activation backward
|
||||||
@@ -332,7 +343,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
|
|
||||||
if dX is not None:
|
if dX is not None:
|
||||||
# Up projection gradients
|
# Up projection gradients
|
||||||
up_weight = dequantize(up_weight.t(), up_quant)
|
up_weight = dequantize_weight(up_weight, up_quant, transpose=True)
|
||||||
if ctx.inplace:
|
if ctx.inplace:
|
||||||
dX = torch.matmul(grad_up, up_weight.t(), out=X)
|
dX = torch.matmul(grad_up, up_weight.t(), out=X)
|
||||||
else:
|
else:
|
||||||
@@ -344,7 +355,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
||||||
|
|
||||||
# Gate projection gradients
|
# Gate projection gradients
|
||||||
gate_weight = dequantize(gate_weight, gate_quant)
|
gate_weight = dequantize_weight(gate_weight, gate_quant)
|
||||||
dX += grad_gate @ gate_weight
|
dX += grad_gate @ gate_weight
|
||||||
del gate_weight
|
del gate_weight
|
||||||
|
|
||||||
@@ -631,7 +642,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
out_buffer = X if ctx.inplace else None
|
out_buffer = X if ctx.inplace else None
|
||||||
|
|
||||||
# Q path
|
# Q path
|
||||||
q_weight_t = dequantize(q_weight, q_quant)
|
q_weight_t = dequantize_weight(q_weight, q_quant)
|
||||||
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
|
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
|
||||||
del q_weight
|
del q_weight
|
||||||
del q_weight_t
|
del q_weight_t
|
||||||
@@ -639,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
|
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
|
||||||
|
|
||||||
# K path
|
# K path
|
||||||
k_weight_t = dequantize(k_weight, k_quant)
|
k_weight_t = dequantize_weight(k_weight, k_quant)
|
||||||
grad_X.addmm_(k_grad, k_weight_t)
|
grad_X.addmm_(k_grad, k_weight_t)
|
||||||
del k_weight
|
del k_weight
|
||||||
del k_weight_t
|
del k_weight_t
|
||||||
@@ -647,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
|
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
|
||||||
|
|
||||||
# V path
|
# V path
|
||||||
v_weight_t = dequantize(v_weight, v_quant)
|
v_weight_t = dequantize_weight(v_weight, v_quant)
|
||||||
grad_X.addmm_(v_grad, v_weight_t)
|
grad_X.addmm_(v_grad, v_weight_t)
|
||||||
del v_weight
|
del v_weight
|
||||||
del v_weight_t
|
del v_weight_t
|
||||||
@@ -810,7 +821,7 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
d_B = s * A @ dY_X
|
d_B = s * A @ dY_X
|
||||||
|
|
||||||
# Get derivative for dX
|
# Get derivative for dX
|
||||||
W = dequantize(W.t(), W_quant)
|
W = dequantize_weight(W, W_quant, transpose=True)
|
||||||
dX = dY @ W.t()
|
dX = dY @ W.t()
|
||||||
del W
|
del W
|
||||||
|
|
||||||
|
|||||||
@@ -146,3 +146,43 @@ def dequantize(
|
|||||||
# Handle transposed data
|
# Handle transposed data
|
||||||
is_transposed: bool = W.shape[0] == 1
|
is_transposed: bool = W.shape[0] == 1
|
||||||
return out.t() if is_transposed else out
|
return out.t() if is_transposed else out
|
||||||
|
|
||||||
|
|
||||||
|
def dequantize_weight(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    transpose: bool = False,
) -> torch.Tensor:
    """Unified dequantization for both torchao and bnb quantized weights.

    Dispatch order:

    1. ``quant_state`` is None and ``W`` is a tensor subclass other than a plain
       ``torch.Tensor`` / ``torch.nn.Parameter``: treated as a torchao tensor and
       dequantized through its own instance method (``get_original_weight`` if
       present, otherwise ``dequantize``).
    2. Everything else (bnb weights with a ``quant_state``, plain tensors, plain
       parameters, or a torchao-style tensor whose ``dequantize`` call failed):
       delegated to ``dequantize``.

    Args:
        W: Weight tensor, documented by callers as ``[out_features, in_features]``.
        quant_state: bnb ``QuantState`` (None for torchao / unquantized).
        transpose: If True, return the transposed result.

    Returns:
        Dequantized tensor, optionally transposed.
    """
    # A plain nn.Parameter is also "not exactly torch.Tensor", but it carries no
    # embedded quantization state. Keep it out of the torchao branch so that
    # unquantized weights are never probed with `.dequantize()` and instead take
    # the same `dequantize(W, None)` path as plain tensors.
    is_plain_tensor = type(W) in (torch.Tensor, torch.nn.Parameter)

    # torchao path: tensor subclass with embedded quantization state
    if quant_state is None and not is_plain_tensor:
        result = None
        # NF4Tensor (check first — NF4Tensor.dequantize is a static method)
        if hasattr(W, "get_original_weight"):
            result = W.get_original_weight()
        else:
            # AffineQuantizedTensor (INT4, etc.)
            try:
                result = W.dequantize()
            except (TypeError, RuntimeError):
                # Best effort: a subclass whose `dequantize` cannot be called as
                # an instance method is handed to the generic path below rather
                # than aborting the kernel.
                result = None
        if result is not None:
            return result.t() if transpose else result

    # bnb path: transpose input before the CUDA kernel (existing convention)
    if transpose:
        return dequantize(W.t(), quant_state)
    return dequantize(W, quant_state)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from bitsandbytes.functional import QuantState
|
from bitsandbytes.functional import QuantState
|
||||||
|
|
||||||
from axolotl.kernels.quantize import dequantize
|
from axolotl.kernels.quantize import dequantize, dequantize_weight
|
||||||
|
|
||||||
|
|
||||||
def test_dequantize_null_state():
|
def test_dequantize_null_state():
|
||||||
@@ -100,3 +100,18 @@ def test_dequantize_output_tensor():
|
|||||||
|
|
||||||
result = dequantize(W, quant_state, out=out)
|
result = dequantize(W, quant_state, out=out)
|
||||||
assert result is out
|
assert result is out
|
||||||
|
|
||||||
|
|
||||||
|
def test_dequantize_weight_plain_tensor():
    """An unquantized tensor without quant state comes back with identical values"""
    weight = torch.randn(32, 64)

    dequantized = dequantize_weight(weight, quant_state=None, transpose=False)

    assert torch.equal(dequantized, weight)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dequantize_weight_plain_tensor_transpose():
    """With transpose=True an unquantized tensor is returned transposed"""
    weight = torch.randn(32, 64)
    expected = weight.t()

    transposed = dequantize_weight(weight, quant_state=None, transpose=True)

    assert transposed.shape == (64, 32)
    assert torch.equal(transposed, expected)
|
||||||
|
|||||||
Reference in New Issue
Block a user