From 1f7f5e7c269b2ced5a5be8125fd21ad31ba59940 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 16 Feb 2026 21:25:50 +0700 Subject: [PATCH] feat: handle lora kernels compat with torchao --- src/axolotl/kernels/lora.py | 33 ++++++++++++++++-------- src/axolotl/kernels/quantize.py | 40 ++++++++++++++++++++++++++++++ tests/e2e/kernels/test_quantize.py | 17 ++++++++++++- 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index c3356fb90..919305931 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -15,7 +15,7 @@ from torch import nn from torch.distributed.tensor import DTensor from .geglu import geglu_backward, geglu_forward -from .quantize import dequantize +from .quantize import dequantize_weight from .swiglu import swiglu_backward, swiglu_forward from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd @@ -46,6 +46,12 @@ def get_lora_parameters( W = base_layer.weight b = base_layer.bias + # Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy + # attribute access to the underlying tensor subclass, so torchao methods like + # .dequantize() or .get_original_weight() would not be visible. + if isinstance(W, DTensor): + W = W.full_tensor() + if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: quant_state = getattr(W, "quant_state", None) return W, b, quant_state, None, None, None @@ -86,6 +92,7 @@ def matmul_lora( B: torch.Tensor | None, s: float | None, out: torch.Tensor | None = None, + transpose: bool = True, ) -> torch.Tensor: """ Efficient fused matmul + LoRA computation. @@ -98,12 +105,15 @@ def matmul_lora( B: LoRA B matrix [out_features, rank] s: LoRA scaling factor out: Optional output tensor for inplace operations + transpose: If True (default), transpose W before matmul (forward path). + Set to False for backward paths where W is already in the correct layout. 
Returns: Result of X @ W + X @ A @ B """ dtype = X.dtype - W = dequantize(W.t(), W_quant) + is_quantized = W_quant is not None or type(W) is not torch.Tensor + W = dequantize_weight(W, W_quant, transpose=transpose) reshape = False if X.dim() == 3: @@ -112,7 +122,7 @@ def matmul_lora( reshape = True out = torch.matmul(X, W, out=out) - if W_quant is not None: + if is_quantized: del W if A is not None: @@ -292,15 +302,16 @@ class LoRA_MLP(torch.autograd.Function): up = up.view(-1, up.shape[-1]) dtype = X.dtype - # Down projection + # Down projection (backward: no transpose needed, W is already [out, in]) grad_down = matmul_lora( grad_output, - down_weight.t(), + down_weight, None, down_quant, down_B, down_A, down_scale, + transpose=False, ) # Activation backward @@ -332,7 +343,7 @@ class LoRA_MLP(torch.autograd.Function): if dX is not None: # Up projection gradients - up_weight = dequantize(up_weight.t(), up_quant) + up_weight = dequantize_weight(up_weight, up_quant, transpose=True) if ctx.inplace: dX = torch.matmul(grad_up, up_weight.t(), out=X) else: @@ -344,7 +355,7 @@ class LoRA_MLP(torch.autograd.Function): dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t()) # Gate projection gradients - gate_weight = dequantize(gate_weight, gate_quant) + gate_weight = dequantize_weight(gate_weight, gate_quant) dX += grad_gate @ gate_weight del gate_weight @@ -631,7 +642,7 @@ class LoRA_QKV(torch.autograd.Function): out_buffer = X if ctx.inplace else None # Q path - q_weight_t = dequantize(q_weight, q_quant) + q_weight_t = dequantize_weight(q_weight, q_quant) grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer) del q_weight del q_weight_t @@ -639,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function): grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled)) # K path - k_weight_t = dequantize(k_weight, k_quant) + k_weight_t = dequantize_weight(k_weight, k_quant) grad_X.addmm_(k_grad, k_weight_t) del k_weight del k_weight_t @@ -647,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function): grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled)) # V path - v_weight_t = dequantize(v_weight, v_quant) + v_weight_t = dequantize_weight(v_weight, v_quant) grad_X.addmm_(v_grad, v_weight_t) del v_weight del v_weight_t @@ -810,7 +821,7 @@ class LoRA_O(torch.autograd.Function): d_B = s * A @ dY_X # Get derivative for dX - W = dequantize(W.t(), W_quant) + W = dequantize_weight(W, W_quant, transpose=True) dX = dY @ W.t() del W diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py index d094f2381..3f6ddded7 100644 --- a/src/axolotl/kernels/quantize.py +++ b/src/axolotl/kernels/quantize.py @@ -146,3 +146,43 @@ def dequantize( # Handle transposed data is_transposed: bool = W.shape[0] == 1 return out.t() if is_transposed else out + + +def dequantize_weight( + W: torch.Tensor, + quant_state: QuantState | list | None = None, + transpose: bool = False, +) -> torch.Tensor: + """Unified dequantization for both torchao and bnb quantized weights. + + For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes + using the appropriate instance method. For bnb Params4bit, delegates to the + optimized CUDA kernel in ``dequantize``. + + Args: + W: Quantized weight tensor ``[out_features, in_features]``. + quant_state: bnb ``QuantState`` (None for torchao / unquantized). + transpose: If True, return ``[in_features, out_features]``. + + Returns: + Dequantized float tensor, optionally transposed. 
+ """ + # torchao path: tensor subclass with embedded quantization state + if quant_state is None and type(W) is not torch.Tensor: + result = None + # NF4Tensor (check first — NF4Tensor.dequantize is a static method) + if hasattr(W, "get_original_weight"): + result = W.get_original_weight() + else: + # AffineQuantizedTensor (INT4, etc.) + try: + result = W.dequantize() + except (TypeError, RuntimeError): + pass + if result is not None: + return result.t() if transpose else result + + # bnb path: transpose input before the CUDA kernel (existing convention) + if transpose: + return dequantize(W.t(), quant_state) + return dequantize(W, quant_state) diff --git a/tests/e2e/kernels/test_quantize.py b/tests/e2e/kernels/test_quantize.py index 60396584c..a93f2119c 100644 --- a/tests/e2e/kernels/test_quantize.py +++ b/tests/e2e/kernels/test_quantize.py @@ -3,7 +3,7 @@ import torch from bitsandbytes.functional import QuantState -from axolotl.kernels.quantize import dequantize +from axolotl.kernels.quantize import dequantize, dequantize_weight def test_dequantize_null_state(): @@ -100,3 +100,18 @@ def test_dequantize_output_tensor(): result = dequantize(W, quant_state, out=out) assert result is out + + +def test_dequantize_weight_plain_tensor(): + """Test that dequantize_weight passes through unquantized tensors unchanged""" + W = torch.randn(32, 64) + result = dequantize_weight(W, quant_state=None, transpose=False) + assert torch.equal(result, W) + + +def test_dequantize_weight_plain_tensor_transpose(): + """Test that dequantize_weight transposes unquantized tensors""" + W = torch.randn(32, 64) + result = dequantize_weight(W, quant_state=None, transpose=True) + assert result.shape == (64, 32) + assert torch.equal(result, W.t())
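
Call-site convention, sketched for unquantized weights: instead of pre-transposing W at backward call sites and letting the internal dequantize(W.t(), ...) undo it, the hunks above pass W in its stored [out_features, in_features] layout and make the desired orientation explicit via the transpose flag. A minimal equivalence check, using only behaviour covered by the new tests (shapes are illustrative):

    import torch

    from axolotl.kernels.quantize import dequantize, dequantize_weight

    W = torch.randn(128, 64)  # [out_features, in_features], unquantized

    # Forward-style: previously the kernel called dequantize(W.t(), ...) itself;
    # now the [in_features, out_features] layout is requested explicitly.
    old_fwd = dequantize(W.t(), None)
    new_fwd = dequantize_weight(W, None, transpose=True)
    assert torch.equal(old_fwd, new_fwd)

    # Backward-style: matmul_lora previously received W.t() and its internal
    # dequantize(W.t(), ...) undid the transpose; transpose=False keeps the
    # stored [out_features, in_features] layout without the round trip.
    old_bwd = dequantize(W.t().t(), None)  # what the old path effectively saw
    new_bwd = dequantize_weight(W, None, transpose=False)
    assert torch.equal(old_bwd, new_bwd)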
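
torchao dispatch, sketched: a torchao tensor subclass carries its own quantization state, so dequantize_weight is called with quant_state=None and dispatches on the subclass itself (get_original_weight() for NF4Tensor, dequantize() otherwise). The snippet below assumes the import path torchao.dtypes.nf4tensor.to_nf4 and its default block sizes, both of which may differ between torchao versions; NF4 is lossy, so the comparison is approximate:

    import torch

    from axolotl.kernels.quantize import dequantize_weight

    # Shape chosen so numel is divisible by the assumed NF4 block sizes (64 / 256).
    W = torch.randn(64, 256, dtype=torch.bfloat16)

    try:
        from torchao.dtypes.nf4tensor import to_nf4  # assumed import path

        W_nf4 = to_nf4(W)  # NF4Tensor exposes get_original_weight(), taken first
        W_deq = dequantize_weight(W_nf4, quant_state=None, transpose=False)
        print("max abs reconstruction error:", (W_deq - W).abs().max().item())
    except ImportError:
        # torchao not installed; the bnb and plain-tensor paths are unaffected.
        pass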