feat: handle lora kernels compat with torchao

NanoCode012 committed 2026-02-16 21:25:50 +07:00
commit 1f7f5e7c26, parent 60c0a828cc
3 changed files with 78 additions and 12 deletions

File 1 of 3

@@ -15,7 +15,7 @@ from torch import nn
 from torch.distributed.tensor import DTensor
 
 from .geglu import geglu_backward, geglu_forward
-from .quantize import dequantize
+from .quantize import dequantize_weight
 from .swiglu import swiglu_backward, swiglu_forward
 from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
 
@@ -46,6 +46,12 @@ def get_lora_parameters(
     W = base_layer.weight
     b = base_layer.bias
 
+    # Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy
+    # attribute access to the underlying tensor subclass, so torchao methods like
+    # .dequantize() or .get_original_weight() would not be visible.
+    if isinstance(W, DTensor):
+        W = W.full_tensor()
+
     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
         quant_state = getattr(W, "quant_state", None)
         return W, b, quant_state, None, None, None
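
For context, a minimal standalone sketch of the unwrap this hunk performs (hedged: assumes the weight was sharded by FSDP2 as a DTensor; `full_tensor()` is the standard way to all-gather back to the wrapped tensor subclass):

    import torch
    from torch.distributed.tensor import DTensor

    def unwrap_weight(W: torch.Tensor) -> torch.Tensor:
        """Materialize a sharded weight so subclass methods are reachable.

        DTensor does not forward attribute lookups to the tensor it wraps,
        so hasattr(W, "get_original_weight") stays False until the full
        tensor is gathered.
        """
        return W.full_tensor() if isinstance(W, DTensor) else W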
@@ -86,6 +92,7 @@ def matmul_lora(
     B: torch.Tensor | None,
     s: float | None,
     out: torch.Tensor | None = None,
+    transpose: bool = True,
 ) -> torch.Tensor:
     """
     Efficient fused matmul + LoRA computation.
@@ -98,12 +105,15 @@
         B: LoRA B matrix [out_features, rank]
         s: LoRA scaling factor
         out: Optional output tensor for inplace operations
+        transpose: If True (default), transpose W before matmul (forward path).
+            Set to False for backward paths where W is already in the correct layout.
 
     Returns:
         Result of X @ W + X @ A @ B
     """
     dtype = X.dtype
-    W = dequantize(W.t(), W_quant)
+    is_quantized = W_quant is not None or type(W) is not torch.Tensor
+    W = dequantize_weight(W, W_quant, transpose=transpose)
 
     reshape = False
     if X.dim() == 3:
@@ -112,7 +122,7 @@
         reshape = True
 
     out = torch.matmul(X, W, out=out)
-    if W_quant is not None:
+    if is_quantized:
         del W
 
     if A is not None:
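
In plain torch terms, the two layouts the new flag selects (a sketch with assumed shapes; W is stored `[out_features, in_features]` per the docstring):

    import torch

    X = torch.randn(4, 64)    # [batch, in_features]
    W = torch.randn(128, 64)  # [out_features, in_features]
    g = torch.randn(4, 128)   # upstream gradient, [batch, out_features]

    # transpose=True (forward path): dequantize W to [in, out], then X @ W
    Y = X @ W.t()             # -> [4, 128]

    # transpose=False (backward path): W is already in the layout the
    # matmul needs, as in the grad_down call in the next hunk
    dX = g @ W                # -> [4, 64]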
@@ -292,15 +302,16 @@ class LoRA_MLP(torch.autograd.Function):
         up = up.view(-1, up.shape[-1])
         dtype = X.dtype
 
-        # Down projection
+        # Down projection (backward: no transpose needed, W is already [out, in])
         grad_down = matmul_lora(
             grad_output,
-            down_weight.t(),
+            down_weight,
             None,
             down_quant,
             down_B,
             down_A,
             down_scale,
+            transpose=False,
         )
 
         # Activation backward
@@ -332,7 +343,7 @@ class LoRA_MLP(torch.autograd.Function):
 
         if dX is not None:
             # Up projection gradients
-            up_weight = dequantize(up_weight.t(), up_quant)
+            up_weight = dequantize_weight(up_weight, up_quant, transpose=True)
             if ctx.inplace:
                 dX = torch.matmul(grad_up, up_weight.t(), out=X)
             else:
@@ -344,7 +355,7 @@ class LoRA_MLP(torch.autograd.Function):
             dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
 
             # Gate projection gradients
-            gate_weight = dequantize(gate_weight, gate_quant)
+            gate_weight = dequantize_weight(gate_weight, gate_quant)
             dX += grad_gate @ gate_weight
             del gate_weight
 
@@ -631,7 +642,7 @@ class LoRA_QKV(torch.autograd.Function):
         out_buffer = X if ctx.inplace else None
 
         # Q path
-        q_weight_t = dequantize(q_weight, q_quant)
+        q_weight_t = dequantize_weight(q_weight, q_quant)
         grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
         del q_weight
         del q_weight_t
@@ -639,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
         grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
 
         # K path
-        k_weight_t = dequantize(k_weight, k_quant)
+        k_weight_t = dequantize_weight(k_weight, k_quant)
         grad_X.addmm_(k_grad, k_weight_t)
         del k_weight
         del k_weight_t
@@ -647,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
         grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
 
         # V path
-        v_weight_t = dequantize(v_weight, v_quant)
+        v_weight_t = dequantize_weight(v_weight, v_quant)
         grad_X.addmm_(v_grad, v_weight_t)
         del v_weight
         del v_weight_t
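
For reference, the accumulation pattern shared by the Q/K/V paths above: `addmm_` fuses "add" and "matmul" into one in-place call (standalone sketch with assumed shapes):

    import torch

    grad_X = torch.zeros(4, 64)        # running input gradient
    k_grad = torch.randn(4, 128)       # upstream grad for the K projection
    k_weight_t = torch.randn(128, 64)  # dequantized K weight

    # In-place fused "grad_X += k_grad @ k_weight_t"
    grad_X.addmm_(k_grad, k_weight_t)
    assert torch.allclose(grad_X, k_grad @ k_weight_t)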
@@ -810,7 +821,7 @@ class LoRA_O(torch.autograd.Function):
         d_B = s * A @ dY_X
 
         # Get derivative for dX
-        W = dequantize(W.t(), W_quant)
+        W = dequantize_weight(W, W_quant, transpose=True)
         dX = dY @ W.t()
         del W
 

File 2 of 3

@@ -146,3 +146,43 @@ def dequantize(
 
     # Handle transposed data
     is_transposed: bool = W.shape[0] == 1
     return out.t() if is_transposed else out
+
+
+def dequantize_weight(
+    W: torch.Tensor,
+    quant_state: QuantState | list | None = None,
+    transpose: bool = False,
+) -> torch.Tensor:
+    """Unified dequantization for both torchao and bnb quantized weights.
+
+    For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes
+    using the appropriate instance method. For bnb Params4bit, delegates to the
+    optimized CUDA kernel in ``dequantize``.
+
+    Args:
+        W: Quantized weight tensor ``[out_features, in_features]``.
+        quant_state: bnb ``QuantState`` (None for torchao / unquantized).
+        transpose: If True, return ``[in_features, out_features]``.
+
+    Returns:
+        Dequantized float tensor, optionally transposed.
+    """
+    # torchao path: tensor subclass with embedded quantization state
+    if quant_state is None and type(W) is not torch.Tensor:
+        result = None
+        # NF4Tensor (check first -- NF4Tensor.dequantize is a static method)
+        if hasattr(W, "get_original_weight"):
+            result = W.get_original_weight()
+        else:
+            # AffineQuantizedTensor (INT4, etc.)
+            try:
+                result = W.dequantize()
+            except (TypeError, RuntimeError):
+                pass
+        if result is not None:
+            return result.t() if transpose else result
+
+    # bnb path: transpose input before the CUDA kernel (existing convention)
+    if transpose:
+        return dequantize(W.t(), quant_state)
+    return dequantize(W, quant_state)
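
A hedged usage sketch of the three dispatch paths; the torchao import location and bnb attribute names are assumptions (hence commented out), while the plain-tensor path mirrors the tests below:

    import torch
    from axolotl.kernels.quantize import dequantize_weight

    # Plain tensor: falls through to `dequantize`, which passes the tensor
    # through when quant_state is None (cf. test_dequantize_null_state)
    W = torch.randn(32, 64)
    assert torch.equal(dequantize_weight(W, None), W)
    assert dequantize_weight(W, None, transpose=True).shape == (64, 32)

    # torchao NF4 (assumed import path): routed via get_original_weight()
    # from torchao.dtypes.nf4tensor import to_nf4
    # W_deq = dequantize_weight(to_nf4(W.to(torch.bfloat16)), None)

    # bnb 4-bit (assumed attribute names): routed to the CUDA kernel
    # W_deq = dequantize_weight(param.data, param.quant_state)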

File 3 of 3

@@ -3,7 +3,7 @@
 import torch
 from bitsandbytes.functional import QuantState
 
-from axolotl.kernels.quantize import dequantize
+from axolotl.kernels.quantize import dequantize, dequantize_weight
 
 
 def test_dequantize_null_state():
@@ -100,3 +100,18 @@ def test_dequantize_output_tensor():
     result = dequantize(W, quant_state, out=out)
 
     assert result is out
+
+
+def test_dequantize_weight_plain_tensor():
+    """Test that dequantize_weight passes through unquantized tensors unchanged"""
+    W = torch.randn(32, 64)
+    result = dequantize_weight(W, quant_state=None, transpose=False)
+    assert torch.equal(result, W)
+
+
+def test_dequantize_weight_plain_tensor_transpose():
+    """Test that dequantize_weight transposes unquantized tensors"""
+    W = torch.randn(32, 64)
+    result = dequantize_weight(W, quant_state=None, transpose=True)
+    assert result.shape == (64, 32)
+    assert torch.equal(result, W.t())
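
A natural follow-up test for the torchao branch, sketched here rather than taken from the diff (assumes torchao is installed and exposes `to_nf4` under `torchao.dtypes.nf4tensor`):

    import pytest

    def test_dequantize_weight_nf4():
        """Sketch: NF4Tensor should dequantize via get_original_weight()"""
        pytest.importorskip("torchao")
        from torchao.dtypes.nf4tensor import to_nf4

        W = torch.randn(64, 64, dtype=torch.bfloat16)
        result = dequantize_weight(to_nf4(W), quant_state=None)
        assert result.shape == W.shape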