feat: handle lora kernels compat with torchao

This commit is contained in:
NanoCode012
2026-02-16 21:25:50 +07:00
parent 60c0a828cc
commit 1f7f5e7c26
3 changed files with 78 additions and 12 deletions

View File

@@ -3,7 +3,7 @@
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
from axolotl.kernels.quantize import dequantize, dequantize_weight
def test_dequantize_null_state():
@@ -100,3 +100,18 @@ def test_dequantize_output_tensor():
result = dequantize(W, quant_state, out=out)
assert result is out
def test_dequantize_weight_plain_tensor():
"""Test that dequantize_weight passes through unquantized tensors unchanged"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=False)
assert torch.equal(result, W)
def test_dequantize_weight_plain_tensor_transpose():
"""Test that dequantize_weight transposes unquantized tensors"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=True)
assert result.shape == (64, 32)
assert torch.equal(result, W.t())