Activation function Triton kernels, LoRA custom autograd functions (#2324)
* LoRA + activation fn Triton kernels: initial commit * implementing optims * finalizing MLP LoRA kernels and progress on QKV / W kernels * updates * O projection optim * adding monkey patching logic * doc strings, typing, pre-commit fixes * updates * adding lora 8b kernels example * working on fsdp support * tests and fixes * small fixes, getting tests to pass, adding doc strings * integration tests for LoRA patching * config.qmd * remove unneeded pytest fixture * fix * review comments first pass * improving tests, attention class agnostic patching * adding support for more archs * wip SiLU / GELU impls * improved testing, small updates, etc. * slightly updating docs * rebase * fixing test_attention_patching_integration * additional review comments, fixing test in CI (hopefully) * isolating problematic patching test * relaxing allclose threshold to reduce flakiness * fixing accidental change * adding model arch agnostic attention class fetching * removing unused activations
This commit is contained in:
103
tests/e2e/kernels/test_quantize.py
Normal file
103
tests/e2e/kernels/test_quantize.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Tests for quantization utility functions."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
from axolotl.kernels.quantize import dequantize
|
||||
|
||||
|
||||
def test_dequantize_null_state():
    """Check that ``dequantize`` hands back an equal tensor when no quant state is given."""
    weights = torch.randn(32, 32)

    # With ``None`` as the quantization state the values must come back unchanged.
    passthrough = dequantize(weights, None)

    assert torch.equal(passthrough, weights)
|
||||
|
||||
|
||||
def test_dequantize_shape_preservation():
    """Check that the dequantized result has the expected shape, dtype and device."""
    expected_shape = (32, 32)
    n_rows = expected_shape[0]
    weights = torch.randn(expected_shape, device="cuda")

    # Random integer tables in [0, 15) used as the ``code`` entries. They are
    # drawn outer-first, nested-second so the RNG is consumed in a fixed order.
    outer_code = torch.randint(0, 15, expected_shape, device="cuda")
    nested_code = torch.randint(0, 15, expected_shape, device="cuda")

    # Second-level state: no offset and no further nesting.
    nested_state = QuantState(
        absmax=torch.ones(n_rows, device="cuda"),
        shape=expected_shape,
        code=nested_code,
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=None,
        state2=None,
    )

    # Top-level state wrapping the nested one; everything lives on the GPU.
    outer_state = QuantState(
        absmax=torch.ones(n_rows, device="cuda"),
        shape=expected_shape,
        code=outer_code,
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(n_rows, dtype=torch.int32, device="cuda"),
        state2=nested_state,
    )

    dequantized = dequantize(weights, outer_state)

    assert dequantized.shape == expected_shape
    assert dequantized.dtype == torch.float16
    assert dequantized.device == weights.device
|
||||
|
||||
|
||||
def test_dequantize_transposed():
    """Test dequantization when the weight is passed as a single-row tensor.

    The weight has shape ``(1, shape[1])`` while the quantization state
    describes a ``shape``-sized tensor. Only the leading dimension of the
    result is asserted.

    All quantization-state tensors are created on the same device as ``W``,
    as ``test_dequantize_shape_preservation`` does, so the call never mixes
    CPU state with a CUDA weight.
    """
    shape = (32, 32)
    device = "cuda"
    W = torch.randn(1, shape[1], device=device)  # Transposed input

    quant_state = QuantState(
        absmax=torch.ones(1, device=device),
        shape=shape,
        code=torch.randint(0, 15, shape, device=device),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(1, dtype=torch.int32, device=device),
        state2=QuantState(
            absmax=torch.ones(1, device=device),
            shape=shape,
            code=torch.randint(0, 15, shape, device=device),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    assert result.shape[0] == shape[0]
|
||||
|
||||
|
||||
def test_dequantize_output_tensor():
    """Test dequantization with a caller-provided output tensor.

    A preallocated float16 buffer is passed via ``out`` and the test asserts
    that ``dequantize`` returns that very object.

    All quantization-state tensors are created on the same device as ``W``
    and ``out``, as ``test_dequantize_shape_preservation`` does, so the call
    never mixes CPU state with CUDA buffers.
    """
    shape = (32, 32)
    device = "cuda"
    W = torch.randn(shape, device=device)
    out = torch.empty(shape, dtype=torch.float16, device=device)

    quant_state = QuantState(
        absmax=torch.ones(shape[0], device=device),
        shape=shape,
        code=torch.randint(0, 15, shape, device=device),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32, device=device),
        state2=QuantState(
            absmax=torch.ones(shape[0], device=device),
            shape=shape,
            code=torch.randint(0, 15, shape, device=device),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state, out=out)
    # Identity check: the provided buffer itself must be returned, not a copy.
    assert result is out
|
||||
Reference in New Issue
Block a user