# Copyright 2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for Triton kernels: entropy_from_logits and selective_log_softmax.
|
|
|
|
Adapted from harness/test_entropy.py and harness/test_selective_logsoftmax.py
|
|
into proper pytest tests, plus new OOB index safety tests.
|
|
"""

import math

import pytest
import torch
import torch.nn.functional as F

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)


# ---------------------------------------------------------------------------
# Reference implementations
# ---------------------------------------------------------------------------


def _ref_entropy(logits):
    """Reference entropy via log_softmax (numerically stable)."""
    logp = F.log_softmax(logits.float(), dim=-1)
    return -(logp.exp() * logp).sum(dim=-1)


def _ref_selective_log_softmax(logits, index):
    """Reference selective log softmax via PyTorch gather."""
    squeeze = index.ndim == logits.ndim - 1
    if squeeze:
        index = index.unsqueeze(-1)
    log_probs = F.log_softmax(logits.float(), dim=-1)
    result = torch.gather(log_probs, dim=-1, index=index)
    if squeeze:
        result = result.squeeze(-1)
    return result


# ---------------------------------------------------------------------------
# entropy_from_logits
# ---------------------------------------------------------------------------


class TestEntropyFromLogits:
    @pytest.mark.parametrize(
        "B,L",
        [
            (1, 128),
            (1, 2048),
            (4, 512),
            (8, 256),
            (1, 1),
        ],
    )
    def test_correctness_various_shapes(self, B, L):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 1024
        torch.manual_seed(42)
        logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        assert result.shape == (B, L)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(16, 256, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        assert result.shape == (16,)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_large_vocab(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 32000
        logits = torch.randn(2, V, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_uniform_distribution(self):
        """Uniform logits -> entropy = log(V)."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 1024
        logits = torch.zeros(2, V, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected_val = math.log(V)
        torch.testing.assert_close(
            result,
            torch.full((2,), expected_val, device="cuda", dtype=torch.float32),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_peaked_distribution(self):
        """One-hot-like logits -> entropy near 0."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.full((2, 128), -100.0, device="cuda", dtype=torch.float32)
        logits[:, 0] = 100.0
        result = entropy_from_logits(logits)
        assert (result < 1e-3).all()

    def test_bfloat16(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits.float())
        assert result.dtype == torch.bfloat16
        torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)

    def test_float16(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(4, 256, device="cuda", dtype=torch.float16)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits.float())
        assert result.dtype == torch.float16
        torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)

    def test_non_contiguous_3d_transpose(self):
        """Non-contiguous 3D tensor via transpose(0,1)."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 256
        raw = torch.randn(32, 4, V, device="cuda", dtype=torch.float32)
        logits = raw.transpose(0, 1)  # (4, 32, V) non-contiguous
        assert not logits.is_contiguous()
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_non_contiguous_3d_slice(self):
        """Non-contiguous 3D tensor via batch slicing."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        V = 256
        raw = torch.randn(8, 32, V, device="cuda", dtype=torch.float32)
        logits = raw[::2]  # (4, 32, V) non-contiguous
        assert not logits.is_contiguous()
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_many_rows_beyond_max_grid(self):
        """More rows than MAX_GRID (8192) to test chunked dispatch."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(10000, 128, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        expected = _ref_entropy(logits)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_entropy_non_negative(self):
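        # Shannon entropy is >= 0 for any distribution; the -1e-5 slack only
        # absorbs float32 rounding error.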
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        logits = torch.randn(32, 512, device="cuda", dtype=torch.float32)
        result = entropy_from_logits(logits)
        assert (result >= -1e-5).all(), f"Negative entropy: {result.min()}"


# ---------------------------------------------------------------------------
# selective_log_softmax — forward correctness
# ---------------------------------------------------------------------------


class TestSelectiveLogSoftmax:
    @pytest.mark.parametrize(
        "B,L,K",
        [
            (1, 128, 1),
            (4, 512, 1),
            (8, 256, 1),
            (4, 256, 4),
            (4, 256, 7),
            (15, 129, 1),  # non-power-of-2
        ],
    )
    def test_correctness_various_shapes(self, B, L, K):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 1024
        torch.manual_seed(42)
        logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32)
        if K == 1:
            index = torch.randint(0, V, (B, L), device="cuda")
        else:
            index = torch.randint(0, V, (B, L, K), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_squeezed_index(self):
        """Index with ndim == logits.ndim - 1 triggers squeeze path."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (8,), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        assert result.shape == (8,)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_large_vocab(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 32000
        logits = torch.randn(2, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (2, 1), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_bfloat16(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 1024
        torch.manual_seed(42)
        logits = torch.randn(4, 128, V, device="cuda", dtype=torch.bfloat16)
        index = torch.randint(0, V, (4, 128), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits.float(), index)
        assert result.dtype == torch.bfloat16
        torch.testing.assert_close(result.float(), expected, atol=0.1, rtol=0.1)

    def test_fp32_tight_tolerance(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 1024
        torch.manual_seed(42)
        logits = torch.randn(2, 256, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (2, 256), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)

    def test_all_same_index(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
        index = torch.zeros(8, 1, device="cuda", dtype=torch.long)
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_last_index(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
        index = torch.full((8, 1), V - 1, device="cuda", dtype=torch.long)
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)

    def test_output_always_nonpositive(self):
        """Log softmax values should always be <= 0."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        logits = torch.randn(32, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (32, 1), device="cuda")
        result = selective_log_softmax(logits, index)
        assert (result <= 1e-5).all(), f"Positive log-prob: {result.max()}"

    def test_many_rows_beyond_max_grid(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(10000, V, device="cuda", dtype=torch.float32)
        index = torch.randint(0, V, (10000, 1), device="cuda")
        result = selective_log_softmax(logits, index)
        expected = _ref_selective_log_softmax(logits, index)
        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)


# ---------------------------------------------------------------------------
# selective_log_softmax — backward / gradient correctness
# ---------------------------------------------------------------------------


class TestSelectiveLogSoftmaxBackward:
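    # For out = log_softmax(x)[y], the analytic gradient is
    # d(out)/d(x_j) = onehot(y)_j - softmax(x)_j; comparing against autograd
    # on the PyTorch reference exercises the kernel's backward end to end.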
    @pytest.mark.parametrize(
        "B,L,V,K",
        [
            (2, 16, 64, 1),
            (2, 16, 64, 4),
            (1, 8, 128, 1),
            (2, 8, 128, 7),
        ],
    )
    def test_gradient_matches_reference(self, B, L, V, K):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        torch.manual_seed(42)
        logits_ref = torch.randn(
            B, L, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        logits_tri = logits_ref.detach().clone().requires_grad_(True)

        if K == 1:
            index = torch.randint(0, V, (B, L), device="cuda")
        else:
            index = torch.randint(0, V, (B, L, K), device="cuda")

        ref_out = _ref_selective_log_softmax(logits_ref, index)
        tri_out = selective_log_softmax(logits_tri, index)

        ref_out.sum().backward()
        tri_out.sum().backward()

        torch.testing.assert_close(
            logits_tri.grad, logits_ref.grad, atol=1e-5, rtol=1e-5
        )

    def test_gradient_bfloat16_full_vocab(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 4096
        torch.manual_seed(42)
        logits_ref = torch.randn(
            2, 64, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        logits_tri = logits_ref.detach().clone().requires_grad_(True)
        index = torch.randint(0, V, (2, 64), device="cuda")

        _ref_selective_log_softmax(logits_ref, index).sum().backward()
        selective_log_softmax(logits_tri, index).sum().backward()

        torch.testing.assert_close(
            logits_tri.grad.float(), logits_ref.grad.float(), atol=0.1, rtol=0.1
        )

    def test_gradient_k1_squeezed(self):
        """Gradient with squeezed (1D) index."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        logits = torch.randn(
            8, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        index = torch.randint(0, V, (8,), device="cuda")

        result = selective_log_softmax(logits, index)
        result.sum().backward()
        triton_grad = logits.grad.clone()

        logits.grad = None
        ref = torch.gather(
            F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1)
        ).squeeze(-1)
        ref.sum().backward()

        torch.testing.assert_close(triton_grad, logits.grad, atol=1e-4, rtol=1e-4)


# ---------------------------------------------------------------------------
# selective_log_softmax — out-of-bounds index safety
# ---------------------------------------------------------------------------


class TestSelectiveLogSoftmaxOOBSafety:
    """Verify that out-of-range indices don't crash or corrupt valid results."""

    def test_negative_indices_no_crash(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
        index = torch.tensor(
            [[-1], [0], [V - 1], [-5]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        assert result.shape == (4, 1)
        # Valid rows should be finite and match reference
        valid_idx = torch.tensor([[0], [V - 1]], device="cuda", dtype=torch.long)
        valid_logits = logits[1:3]
        expected = _ref_selective_log_softmax(valid_logits, valid_idx)
        torch.testing.assert_close(result[1:3], expected, atol=1e-4, rtol=1e-4)

    def test_index_exceeds_vocab_no_crash(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
        index = torch.tensor(
            [[0], [V], [V + 100], [V - 1]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        assert result.shape == (4, 1)
        # Valid rows (0 and 3) should match reference
        for row_idx, idx_val in [(0, 0), (3, V - 1)]:
            ref = _ref_selective_log_softmax(
                logits[row_idx : row_idx + 1],
                torch.tensor([[idx_val]], device="cuda", dtype=torch.long),
            )
            torch.testing.assert_close(
                result[row_idx : row_idx + 1], ref, atol=1e-4, rtol=1e-4
            )

    def test_mixed_valid_invalid_multi_index(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        K = 3
        logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
        index = torch.tensor(
            [
                [0, 10, -1],  # last invalid
                [V, 5, 100],  # first invalid
                [50, 60, 70],  # all valid
                [-1, V + 1, -100],  # all invalid
            ],
            device="cuda",
            dtype=torch.long,
        )
        result = selective_log_softmax(logits, index)
        assert result.shape == (4, K)
        # Row 2 (all valid) must match reference exactly
        valid_index = torch.tensor([[50, 60, 70]], device="cuda", dtype=torch.long)
        expected = _ref_selective_log_softmax(logits[2:3], valid_index)
        torch.testing.assert_close(result[2:3], expected, atol=1e-4, rtol=1e-4)

    def test_oob_backward_no_crash(self):
        """Backward with OOB indices should not crash and grads should be finite."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(
            4, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        index = torch.tensor(
            [[-1], [0], [V + 10], [V - 1]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        result.sum().backward()
        assert logits.grad is not None
        assert torch.isfinite(logits.grad).all()

    def test_oob_backward_valid_rows_correct(self):
        """Gradients for valid-index rows should match reference even when other rows have OOB."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        logits = torch.randn(
            4, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        # Row 0: invalid, Row 1: valid, Row 2: invalid, Row 3: valid
        index = torch.tensor(
            [[-1], [42], [V + 5], [100]], device="cuda", dtype=torch.long
        )
        result = selective_log_softmax(logits, index)
        result.sum().backward()

        # Compute reference gradient for valid rows only
        logits_ref = logits.detach().clone().requires_grad_(True)
        valid_rows = [1, 3]
        valid_indices = [42, 100]
        for r, idx in zip(valid_rows, valid_indices, strict=True):
            ref_lp = F.log_softmax(logits_ref[r : r + 1], dim=-1)
            ref_val = ref_lp[0, idx]
            ref_val.backward(retain_graph=True)

        for r in valid_rows:
            torch.testing.assert_close(
                logits.grad[r], logits_ref.grad[r], atol=1e-4, rtol=1e-4
            )