Files
axolotl/tests/test_triton_kernels.py
Wing Lian 163bd4dd5a use custom triton kernels for entropy from logits and selective softmax (#3510)
* use custom triton kernels for entropy from logits and selective softmax

* PR comments fixes

* fix out of bounds, include tests, include benchmarks

* chore: lint
2026-03-19 02:02:43 -04:00

482 lines
18 KiB
Python

# Copyright 2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
"""Unit tests for Triton kernels: entropy_from_logits and selective_log_softmax.
Adapted from harness/test_entropy.py and harness/test_selective_logsoftmax.py
into proper pytest tests, plus new OOB index safety tests.
"""
import math
import pytest
import torch
import torch.nn.functional as F
# Module-level pytest marker: skip every test in this file when
# torch reports that no CUDA device is available.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)
# ---------------------------------------------------------------------------
# Reference implementations
# ---------------------------------------------------------------------------
def _ref_entropy(logits):
    """Reference entropy via log_softmax (numerically stable).

    Args:
        logits: Tensor of unnormalized scores. The distribution is taken over
            the last dimension; the input is upcast to float32 first.

    Returns:
        Float32 tensor with the last dimension reduced away, holding the
        entropy (in nats) of ``softmax(logits)``.
    """
    logp = F.log_softmax(logits.float(), dim=-1)
    p = logp.exp()
    # A logit of -inf gives p == 0 and logp == -inf, and 0 * -inf is NaN.
    # Such entries contribute nothing to the entropy (p * log p -> 0 as
    # p -> 0), so zero them instead of letting NaN poison the row sum.
    plogp = (p * logp).masked_fill(p == 0, 0.0)
    return -plogp.sum(dim=-1)
def _ref_selective_log_softmax(logits, index):
    """Reference selective log softmax via PyTorch gather.

    Args:
        logits: Tensor of unnormalized scores; normalized over the last
            dimension after an upcast to float32.
        index: Integer tensor of positions along the last dimension of
            ``logits``. It may have the same rank as ``logits`` or one rank
            less; in the latter case a trailing axis is added for the gather
            and removed again from the result.

    Returns:
        Float32 tensor of the gathered log-probabilities, with the same shape
        as ``index``.
    """
    # torch.gather is documented to take an int64 index; widen narrower
    # integer dtypes so callers are not forced to cast themselves.
    index = index.long()
    squeeze = index.ndim == logits.ndim - 1
    if squeeze:
        index = index.unsqueeze(-1)
    log_probs = F.log_softmax(logits.float(), dim=-1)
    result = torch.gather(log_probs, dim=-1, index=index)
    if squeeze:
        result = result.squeeze(-1)
    return result
# ---------------------------------------------------------------------------
# entropy_from_logits
# ---------------------------------------------------------------------------
class TestEntropyFromLogits:
    """Check ``entropy_from_logits`` against the PyTorch reference ``_ref_entropy``.

    The kernel is imported inside each test rather than at module scope.
    """

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's RNG before every test so random logits are reproducible."""
        torch.manual_seed(42)

    @pytest.mark.parametrize(
        "B,L",
        [
            (1, 128),
            (1, 2048),
            (4, 512),
            (8, 256),
            (1, 1),
        ],
    )
    def test_correctness_various_shapes(self, B, L):
        """3D (B, L, V) float32 logits match the reference for several B and L."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 1024
        logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        assert result.shape == (B, L)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        """2D (rows, V) logits give one entropy value per row."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(16, 256, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        assert result.shape == (16,)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_large_vocab(self):
        """A 32000-wide last dimension still matches the reference."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 32000
        logits = torch.randn(2, V, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_uniform_distribution(self):
        """Uniform logits -> entropy = log(V)."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 1024
        logits = torch.zeros(2, V, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected_val = math.log(V)
        torch.testing.assert_close(
            result,
            torch.full((2,), expected_val, device="cuda", dtype=torch.float32),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_peaked_distribution(self):
        """One-hot-like logits -> entropy near 0."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.full((2, 128), -100.0, device="cuda", dtype=torch.float32)
        logits[:, 0] = 100.0
        result = entropy_from_logits(logits)
        assert (result < 1e-3).all()

    def test_bfloat16(self):
        """bfloat16 input keeps its dtype and matches the fp32 reference loosely."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits.float())
        assert result.dtype == torch.bfloat16
        torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)

    def test_float16(self):
        """float16 input keeps its dtype and matches the fp32 reference loosely."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(4, 256, device="cuda", dtype=torch.float16)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits.float())
        assert result.dtype == torch.float16
        torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)

    def test_non_contiguous_3d_transpose(self):
        """Non-contiguous 3D tensor via transpose(0,1)."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 256
        raw = torch.randn(32, 4, V, device="cuda", dtype=torch.float32)
        logits = raw.transpose(0, 1)  # (4, 32, V) non-contiguous
        assert not logits.is_contiguous()
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_non_contiguous_3d_slice(self):
        """Non-contiguous 3D tensor via batch slicing."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 256
        raw = torch.randn(8, 32, V, device="cuda", dtype=torch.float32)
        logits = raw[::2]  # (4, 32, V) non-contiguous
        assert not logits.is_contiguous()
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_many_rows_beyond_max_grid(self):
        """More rows than MAX_GRID (8192) to test chunked dispatch."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(10000, 128, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_entropy_non_negative(self):
        """Entropy of random logits is never below a small negative tolerance."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(32, 512, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        assert (result >= -1e-5).all(), f"Negative entropy: {result.min()}"
# ---------------------------------------------------------------------------
# selective_log_softmax — forward correctness
# ---------------------------------------------------------------------------
class TestSelectiveLogSoftmax:
    """Forward checks of ``selective_log_softmax`` against the gather reference.

    The kernel is imported inside each test rather than at module scope.
    """

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's RNG before every test so logits and indices are reproducible."""
        torch.manual_seed(42)

    @pytest.mark.parametrize(
        "B,L,K",
        [
            (1, 128, 1),
            (4, 512, 1),
            (8, 256, 1),
            (4, 256, 4),
            (4, 256, 7),
            (15, 129, 1),  # non-power-of-2
        ],
    )
    def test_correctness_various_shapes(self, B, L, K):
        """(B, L, V) logits with K picks per position match the reference."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 1024
        logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32)
        # K == 1 uses the squeezed (B, L) index form; K > 1 uses (B, L, K).
        index_shape = (B, L) if K == 1 else (B, L, K)
        index = torch.randint(0, V, index_shape, device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_squeezed_index(self):
        """Index with ndim == logits.ndim - 1 triggers squeeze path."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (8,), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        assert result.shape == (8,)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_large_vocab(self):
        """A 32000-wide last dimension still matches the reference."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 32000
        logits = torch.randn(2, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (2, 1), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_bfloat16(self):
        """bfloat16 input keeps its dtype and matches the fp32 reference loosely."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 1024
        logits = torch.randn(4, 128, V, device="cuda", dtype=torch.bfloat16)
        index = torch.randint(0, V, (4, 128), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits.float(), index)
        assert result.dtype == torch.bfloat16
        torch.testing.assert_close(result.float(), expected, atol=0.1, rtol=0.1)

    def test_fp32_tight_tolerance(self):
        """float32 results agree with the reference at 1e-5 tolerance."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 1024
        logits = torch.randn(2, 256, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (2, 256), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)

    def test_all_same_index(self):
        """Every row selects index 0."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
        index = torch.zeros(8, 1, device="cuda", dtype=torch.long)
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_last_index(self):
        """Every row selects the final position, V - 1."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
        index = torch.full((8, 1), V - 1, device="cuda", dtype=torch.long)
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_output_always_nonpositive(self):
        """Log softmax values should always be <= 0."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        logits = torch.randn(32, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (32, 1), device="cuda")
        result = selective_log_softmax(logits, index)
        assert (result <= 1e-5).all(), f"Positive log-prob: {result.max()}"

    def test_many_rows_beyond_max_grid(self):
        """10000 rows of logits still match the reference."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(10000, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (10000, 1), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
# ---------------------------------------------------------------------------
# selective_log_softmax — backward / gradient correctness
# ---------------------------------------------------------------------------
class TestSelectiveLogSoftmaxBackward:
    """Gradient checks of ``selective_log_softmax`` against the autograd reference.

    Each test back-propagates the sum of the outputs through both the kernel
    and ``_ref_selective_log_softmax`` and compares the gradients on the logits.
    """

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's RNG before every test so logits and indices are reproducible."""
        torch.manual_seed(42)

    @pytest.mark.parametrize(
        "B,L,V,K",
        [
            (2, 16, 64, 1),
            (2, 16, 64, 4),
            (1, 8, 128, 1),
            (2, 8, 128, 7),
        ],
    )
    def test_gradient_matches_reference(self, B, L, V, K):
        """float32 logit gradients match the reference for several shapes."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        logits_ref = torch.randn(
            B, L, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        # Independent leaf with identical values, so the two backward passes
        # accumulate into separate .grad tensors.
        logits_tri = logits_ref.detach().clone().requires_grad_(True)
        # K == 1 uses the squeezed (B, L) index form; K > 1 uses (B, L, K).
        index_shape = (B, L) if K == 1 else (B, L, K)
        index = torch.randint(0, V, index_shape, device="cuda")

        ref_out = _ref_selective_log_softmax(logits_ref, index)
        tri_out = selective_log_softmax(logits_tri, index)
        ref_out.sum().backward()
        tri_out.sum().backward()
        torch.testing.assert_close(
            logits_tri.grad, logits_ref.grad, atol=1e-5, rtol=1e-5
        )

    def test_gradient_bfloat16_full_vocab(self):
        """bfloat16 logit gradients match the reference within a loose tolerance."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 4096
        logits_ref = torch.randn(
            2, 64, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        logits_tri = logits_ref.detach().clone().requires_grad_(True)
        index = torch.randint(0, V, (2, 64), device="cuda")

        _ref_selective_log_softmax(logits_ref, index).sum().backward()
        selective_log_softmax(logits_tri, index).sum().backward()
        torch.testing.assert_close(
            logits_tri.grad.float(), logits_ref.grad.float(), atol=0.1, rtol=0.1
        )

    def test_gradient_k1_squeezed(self):
        """Gradient with squeezed (1D) index."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        logits = torch.randn(
            8, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        index = torch.randint(0, V, (8,), device="cuda")

        selective_log_softmax(logits, index).sum().backward()
        triton_grad = logits.grad.clone()

        # The same leaf is reused for the reference pass, so clear the
        # accumulated gradient before back-propagating again.
        logits.grad = None
        _ref_selective_log_softmax(logits, index).sum().backward()
        torch.testing.assert_close(triton_grad, logits.grad, atol=1e-4, rtol=1e-4)
# ---------------------------------------------------------------------------
# selective_log_softmax — out-of-bounds index safety
# ---------------------------------------------------------------------------
class TestSelectiveLogSoftmaxOOBSafety:
    """Verify that out-of-range indices don't crash or corrupt valid results.

    Only entries with in-range indices are compared with the reference; the
    values produced for out-of-range entries are not asserted.
    """

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's RNG before every test so random logits are reproducible."""
        torch.manual_seed(42)

    def test_negative_indices_no_crash(self):
        """Negative indices in some rows leave the in-range rows correct."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
        index = torch.tensor(
            [[-1], [0], [V - 1], [-5]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        assert result.shape == (4, 1)

        # Valid rows should be finite and match reference
        valid_idx = torch.tensor([[0], [V - 1]], device="cuda", dtype=torch.long)
        valid_logits = logits[1:3]
        expected = _ref_selective_log_softmax(valid_logits, valid_idx)
        torch.testing.assert_close(result[1:3], expected, atol=1e-4, rtol=1e-4)

    def test_index_exceeds_vocab_no_crash(self):
        """Indices >= V in some rows leave the in-range rows correct."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
        index = torch.tensor(
            [[0], [V], [V + 100], [V - 1]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        assert result.shape == (4, 1)

        # Valid rows (0 and 3) should match reference
        for row_idx, idx_val in [(0, 0), (3, V - 1)]:
            ref = _ref_selective_log_softmax(
                logits[row_idx : row_idx + 1],
                torch.tensor([[idx_val]], device="cuda", dtype=torch.long),
            )
            torch.testing.assert_close(
                result[row_idx : row_idx + 1], ref, atol=1e-4, rtol=1e-4
            )

    def test_mixed_valid_invalid_multi_index(self):
        """With K=3 picks per row, the all-valid row matches the reference."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        K = 3
        logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
        index = torch.tensor(
            [
                [0, 10, -1],  # last invalid
                [V, 5, 100],  # first invalid
                [50, 60, 70],  # all valid
                [-1, V + 1, -100],  # all invalid
            ],
            device="cuda",
            dtype=torch.long,
        )
        result = selective_log_softmax(logits, index)
        assert result.shape == (4, K)

        # Row 2 (all valid) must match reference exactly
        valid_index = torch.tensor([[50, 60, 70]], device="cuda", dtype=torch.long)
        expected = _ref_selective_log_softmax(logits[2:3], valid_index)
        torch.testing.assert_close(result[2:3], expected, atol=1e-4, rtol=1e-4)

    def test_oob_backward_no_crash(self):
        """Backward with OOB indices should not crash and grads should be finite."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(
            4, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        index = torch.tensor(
            [[-1], [0], [V + 10], [V - 1]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        result.sum().backward()
        assert logits.grad is not None
        assert torch.isfinite(logits.grad).all()

    def test_oob_backward_valid_rows_correct(self):
        """Gradients for valid-index rows should match reference.

        This must hold even when other rows carry out-of-bounds indices.
        """
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(
            4, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        # Row 0: invalid, Row 1: valid, Row 2: invalid, Row 3: valid
        index = torch.tensor(
            [[-1], [42], [V + 5], [100]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        result.sum().backward()

        # Compute reference gradient for valid rows only
        logits_ref = logits.detach().clone().requires_grad_(True)
        valid_rows = [1, 3]
        valid_indices = [42, 100]
        for r, idx in zip(valid_rows, valid_indices, strict=True):
            # Each iteration builds its own graph from the leaf, so the
            # gradients accumulate in logits_ref.grad without retain_graph.
            ref_lp = F.log_softmax(logits_ref[r : r + 1], dim=-1)
            ref_lp[0, idx].backward()

        for r in valid_rows:
            torch.testing.assert_close(
                logits.grad[r], logits_ref.grad[r], atol=1e-4, rtol=1e-4
            )