feat: handle lora kernels compat with torchao

This commit is contained in:
NanoCode012
2026-02-16 21:25:50 +07:00
parent 60c0a828cc
commit 1f7f5e7c26
3 changed files with 78 additions and 12 deletions

View File

@@ -15,7 +15,7 @@ from torch import nn
from torch.distributed.tensor import DTensor from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize from .quantize import dequantize_weight
from .swiglu import swiglu_backward, swiglu_forward from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
@@ -46,6 +46,12 @@ def get_lora_parameters(
W = base_layer.weight W = base_layer.weight
b = base_layer.bias b = base_layer.bias
# Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy
# attribute access to the underlying tensor subclass, so torchao methods like
# .dequantize() or .get_original_weight() would not be visible.
if isinstance(W, DTensor):
W = W.full_tensor()
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None) quant_state = getattr(W, "quant_state", None)
return W, b, quant_state, None, None, None return W, b, quant_state, None, None, None
@@ -86,6 +92,7 @@ def matmul_lora(
B: torch.Tensor | None, B: torch.Tensor | None,
s: float | None, s: float | None,
out: torch.Tensor | None = None, out: torch.Tensor | None = None,
transpose: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Efficient fused matmul + LoRA computation. Efficient fused matmul + LoRA computation.
@@ -98,12 +105,15 @@ def matmul_lora(
B: LoRA B matrix [out_features, rank] B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor s: LoRA scaling factor
out: Optional output tensor for inplace operations out: Optional output tensor for inplace operations
transpose: If True (default), transpose W before matmul (forward path).
Set to False for backward paths where W is already in the correct layout.
Returns: Returns:
Result of X @ W + X @ A @ B Result of X @ W + X @ A @ B
""" """
dtype = X.dtype dtype = X.dtype
W = dequantize(W.t(), W_quant) is_quantized = W_quant is not None or type(W) is not torch.Tensor
W = dequantize_weight(W, W_quant, transpose=transpose)
reshape = False reshape = False
if X.dim() == 3: if X.dim() == 3:
@@ -112,7 +122,7 @@ def matmul_lora(
reshape = True reshape = True
out = torch.matmul(X, W, out=out) out = torch.matmul(X, W, out=out)
if W_quant is not None: if is_quantized:
del W del W
if A is not None: if A is not None:
@@ -292,15 +302,16 @@ class LoRA_MLP(torch.autograd.Function):
up = up.view(-1, up.shape[-1]) up = up.view(-1, up.shape[-1])
dtype = X.dtype dtype = X.dtype
# Down projection # Down projection (backward: no transpose needed, W is already [out, in])
grad_down = matmul_lora( grad_down = matmul_lora(
grad_output, grad_output,
down_weight.t(), down_weight,
None, None,
down_quant, down_quant,
down_B, down_B,
down_A, down_A,
down_scale, down_scale,
transpose=False,
) )
# Activation backward # Activation backward
@@ -332,7 +343,7 @@ class LoRA_MLP(torch.autograd.Function):
if dX is not None: if dX is not None:
# Up projection gradients # Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant) up_weight = dequantize_weight(up_weight, up_quant, transpose=True)
if ctx.inplace: if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X) dX = torch.matmul(grad_up, up_weight.t(), out=X)
else: else:
@@ -344,7 +355,7 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t()) dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients # Gate projection gradients
gate_weight = dequantize(gate_weight, gate_quant) gate_weight = dequantize_weight(gate_weight, gate_quant)
dX += grad_gate @ gate_weight dX += grad_gate @ gate_weight
del gate_weight del gate_weight
@@ -631,7 +642,7 @@ class LoRA_QKV(torch.autograd.Function):
out_buffer = X if ctx.inplace else None out_buffer = X if ctx.inplace else None
# Q path # Q path
q_weight_t = dequantize(q_weight, q_quant) q_weight_t = dequantize_weight(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer) grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight del q_weight
del q_weight_t del q_weight_t
@@ -639,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled)) grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path # K path
k_weight_t = dequantize(k_weight, k_quant) k_weight_t = dequantize_weight(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t) grad_X.addmm_(k_grad, k_weight_t)
del k_weight del k_weight
del k_weight_t del k_weight_t
@@ -647,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled)) grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path # V path
v_weight_t = dequantize(v_weight, v_quant) v_weight_t = dequantize_weight(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t) grad_X.addmm_(v_grad, v_weight_t)
del v_weight del v_weight
del v_weight_t del v_weight_t
@@ -810,7 +821,7 @@ class LoRA_O(torch.autograd.Function):
d_B = s * A @ dY_X d_B = s * A @ dY_X
# Get derivative for dX # Get derivative for dX
W = dequantize(W.t(), W_quant) W = dequantize_weight(W, W_quant, transpose=True)
dX = dY @ W.t() dX = dY @ W.t()
del W del W

View File

@@ -146,3 +146,43 @@ def dequantize(
# Handle transposed data # Handle transposed data
is_transposed: bool = W.shape[0] == 1 is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out return out.t() if is_transposed else out
def dequantize_weight(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    transpose: bool = False,
) -> torch.Tensor:
    """Dequantize a weight coming from torchao, bitsandbytes, or no quantizer.

    Dispatch rules:

    * Plain ``torch.Tensor`` / plain ``torch.nn.Parameter`` and any weight that
      comes with a ``quant_state`` are forwarded to ``dequantize``.
    * Any other tensor subclass without a ``quant_state`` is treated as a
      torchao tensor that embeds its own quantization state:
      ``get_original_weight()`` is used when present (NF4Tensor), otherwise the
      instance method ``dequantize()`` (AffineQuantizedTensor and similar).
    * If the torchao attempt raises ``TypeError`` / ``RuntimeError``, the weight
      falls back to ``dequantize`` as a best effort.

    Args:
        W: Weight tensor ``[out_features, in_features]``, quantized or not.
        quant_state: bnb ``QuantState`` (None for torchao / unquantized).
        transpose: If True, return ``[in_features, out_features]``.

    Returns:
        Dequantized tensor, optionally transposed.
    """
    # Exact-type check on purpose: every quantized flavour is a *subclass* of
    # Tensor/Parameter, so isinstance() could not tell them apart.  A plain
    # nn.Parameter must be excluded here as well -- it carries no quantization
    # state, and calling Tensor.dequantize() on a regular dense tensor is not a
    # pass-through (NOTE(review): it can upcast fp16/bf16 weights to float32
    # and allocate a copy -- confirm against the PyTorch version in use).
    is_plain = type(W) in (torch.Tensor, torch.nn.Parameter)

    # torchao path: tensor subclass with embedded quantization state
    if quant_state is None and not is_plain:
        result = None
        # NF4Tensor is checked first because NF4Tensor.dequantize is a static
        # method and cannot be called as ``W.dequantize()``.
        if hasattr(W, "get_original_weight"):
            result = W.get_original_weight()
        else:
            # AffineQuantizedTensor (INT4, etc.)
            try:
                result = W.dequantize()
            except (TypeError, RuntimeError):
                # Deliberate best effort: an unknown subclass falls through to
                # the bnb / pass-through path below instead of failing here.
                result = None
        if result is not None:
            return result.t() if transpose else result

    # bnb / unquantized path: transpose the *input* before calling
    # ``dequantize`` (existing convention of the callers).
    if transpose:
        return dequantize(W.t(), quant_state)
    return dequantize(W, quant_state)

View File

@@ -3,7 +3,7 @@
import torch import torch
from bitsandbytes.functional import QuantState from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize from axolotl.kernels.quantize import dequantize, dequantize_weight
def test_dequantize_null_state(): def test_dequantize_null_state():
@@ -100,3 +100,18 @@ def test_dequantize_output_tensor():
result = dequantize(W, quant_state, out=out) result = dequantize(W, quant_state, out=out)
assert result is out assert result is out
def test_dequantize_weight_plain_tensor():
    """An unquantized tensor with no quant state must come back unchanged."""
    weight = torch.randn(32, 64)

    passthrough = dequantize_weight(weight, quant_state=None, transpose=False)

    # Values (and therefore shape) must match the input exactly.
    assert torch.equal(passthrough, weight)
def test_dequantize_weight_plain_tensor_transpose():
    """With ``transpose=True`` an unquantized tensor is returned transposed."""
    weight = torch.randn(32, 64)

    flipped = dequantize_weight(weight, quant_state=None, transpose=True)

    # [out_features, in_features] -> [in_features, out_features]
    assert flipped.shape == (64, 32)
    # Element-wise identical to the plain transpose of the input.
    assert torch.equal(flipped, weight.t())