use custom triton kernels for entropy from logits and selective softmax (#3510)

* use custom triton kernels for entropy from logits and selective softmax

* PR comments fixes

* fix out of bounds, include tests, include benchmarks

* chore: lint
This commit is contained in:
Wing Lian
2026-03-19 02:02:43 -04:00
committed by GitHub
parent f291ac029c
commit 163bd4dd5a
6 changed files with 1346 additions and 0 deletions

208
benchmarks/bench_entropy.py Normal file
View File

@@ -0,0 +1,208 @@
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
"""
import gc
import statistics
import torch
import torch.nn.functional as F
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
"""Original chunked implementation (reference)."""
original_shape = logits.shape[:-1]
num_classes = logits.shape[-1]
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, logits, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = fn(logits, chunk_size=128)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
del out
return times
def profile_memory(fn, logits, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(logits, chunk_size=128)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_contiguous():
print("=" * 60)
print(
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(1, 16384),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
t_orig = profile_time(entropy_from_logits_original, logits)
t_triton = profile_time(entropy_from_logits, logits)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(entropy_from_logits_original, logits)
m_triton = profile_memory(entropy_from_logits, logits)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits
_clean_gpu()
def benchmark_noncontiguous():
print("\n" + "=" * 60)
print(
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(4, 2048, "transpose"),
(4, 8192, "transpose"),
(8, 2048, "transpose"),
(4, 4096, "slice_batch"),
]
for B, L, method in configs:
torch.manual_seed(42)
if method == "transpose":
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw.transpose(0, 1)
raw_gb = L * B * V * 2 / 1e9
elif method == "slice_batch":
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw[::2]
raw_gb = B * 2 * L * V * 2 / 1e9
else:
continue
if raw_gb > 28:
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
del raw, logits_nc
torch.cuda.empty_cache()
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
print(f"{'' * 60}")
def original_with_copy(logits, chunk_size=128):
return entropy_from_logits_original(
logits.contiguous(), chunk_size=chunk_size
)
t_orig = profile_time(original_with_copy, logits_nc)
t_triton = profile_time(entropy_from_logits, logits_nc)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" orig+copy: {fmt(t_orig, 'ms')}")
print(f" triton-strided:{fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(original_with_copy, logits_nc)
m_triton = profile_memory(entropy_from_logits, logits_nc)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" orig+copy: {fmt(m_orig, 'MB')}")
print(f" triton-strided:{fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del raw, logits_nc
_clean_gpu()
if __name__ == "__main__":
benchmark_contiguous()
benchmark_noncontiguous()

View File

@@ -0,0 +1,191 @@
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
"""
import gc
import statistics
import torch
from axolotl.monkeypatch.trainer.utils import (
selective_log_softmax,
selective_log_softmax_original,
)
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, args, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
fn(*args)
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn(*args)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return times
def profile_memory(fn, args, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(*args)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(*args)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_forward():
print("=" * 60)
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(selective_log_softmax_original, (logits, index))
t_triton = profile_time(selective_log_softmax, (logits, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
m_triton = profile_memory(selective_log_softmax, (logits, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits, index
_clean_gpu()
def benchmark_backward():
print("\n" + "=" * 60)
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
def fwd_bwd_original(logits, index):
logits.grad = None
out = selective_log_softmax_original(logits, index)
out.sum().backward()
def fwd_bwd_triton(logits, index):
logits.grad = None
out = selective_log_softmax(logits, index)
out.sum().backward()
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 20:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits_orig = torch.randn(
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
logits_tri = logits_orig.detach().clone().requires_grad_(True)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" FWD+BWD TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" FWD+BWD MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits_orig, logits_tri, index
_clean_gpu()
if __name__ == "__main__":
benchmark_forward()
benchmark_backward()

View File

@@ -117,6 +117,7 @@ class PatchManager:
self._apply_voxtral_patches()
self._apply_apertus_patches()
self._apply_trl_vllm_patches()
self._apply_trl_trainer_utils_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -679,6 +680,39 @@ class PatchManager:
patch_trl_vllm()
def _apply_trl_trainer_utils_patches(self):
"""Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels."""
if not self.cfg.rl:
return
try:
from axolotl.monkeypatch.trainer.utils import (
entropy_from_logits,
selective_log_softmax,
)
except (ImportError, ModuleNotFoundError):
LOG.warning("Triton not available — skipping trl.trainer.utils patches")
return
import trl.trainer.utils
# Guard against repeated calls: only stash the original if trl still
# points at its own implementation (not our wrapper).
if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils
_axolotl_trainer_utils.selective_log_softmax_original = (
trl.trainer.utils.selective_log_softmax
)
trl.trainer.utils.selective_log_softmax = selective_log_softmax
if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
trl.trainer.utils.entropy_from_logits = entropy_from_logits
LOG.info(
"Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits"
)
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:

View File

@@ -0,0 +1,3 @@
from .utils import entropy_from_logits, selective_log_softmax
__all__ = ["entropy_from_logits", "selective_log_softmax"]

View File

@@ -0,0 +1,429 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
@triton.jit
def _entropy_online_kernel(
logits_ptr,
output_ptr,
stride_row,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Online entropy: single pass with running max correction."""
row = tl.program_id(0)
row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
running_weighted = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
correction = tl.exp(running_max - new_max)
running_sum_exp = running_sum_exp * correction
running_weighted = running_weighted * correction
exp_x = tl.exp(x - new_max)
exp_x = tl.where(mask, exp_x, 0.0)
x = tl.where(mask, x, 0.0)
running_sum_exp += tl.sum(exp_x, axis=0)
running_weighted += tl.sum(exp_x * x, axis=0)
running_max = new_max
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
tl.store(output_ptr + row, entropy)
@triton.jit
def _entropy_online_kernel_strided(
logits_ptr,
output_ptr,
stride_outer,
stride_inner,
n_inner,
row_offset,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Online entropy for non-contiguous 3D (B, L, V) tensors."""
local_row = tl.program_id(0)
row = local_row + row_offset
outer_idx = row // n_inner
inner_idx = row % n_inner
off = outer_idx.to(tl.int64) * stride_outer + inner_idx.to(tl.int64) * stride_inner
row_ptr = logits_ptr + off
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
running_weighted = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
correction = tl.exp(running_max - new_max)
running_sum_exp = running_sum_exp * correction
running_weighted = running_weighted * correction
exp_x = tl.exp(x - new_max)
exp_x = tl.where(mask, exp_x, 0.0)
x = tl.where(mask, x, 0.0)
running_sum_exp += tl.sum(exp_x, axis=0)
running_weighted += tl.sum(exp_x * x, axis=0)
running_max = new_max
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
tl.store(output_ptr + local_row, entropy)
def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
"""Triton-fused entropy (online single-pass). Handles non-contiguous tensors without copying."""
original_shape = logits.shape[:-1]
V = logits.shape[-1]
N = 1
for s in original_shape:
N *= s
if not logits.is_cuda:
# CPU fallback: stable entropy via log_softmax
logp = F.log_softmax(logits.float(), dim=-1)
ent = -(logp.exp() * logp).sum(dim=-1)
return ent.to(logits.dtype).reshape(original_shape)
output = torch.empty(N, device=logits.device, dtype=torch.float32)
BLOCK_V = 4096
MAX_GRID_CONTIG = 8192
MAX_GRID_STRIDED = 2048
# Vocab (last) dim must be contiguous for coalesced loads
if logits.stride(-1) != 1:
logits = logits.contiguous()
if logits.is_contiguous():
flat_logits = logits.reshape(-1, V)
stride = flat_logits.stride(0)
for start in range(0, N, MAX_GRID_CONTIG):
n_rows = min(MAX_GRID_CONTIG, N - start)
_entropy_online_kernel[(n_rows,)](
flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
)
elif logits.ndim == 3:
stride_outer = logits.stride(0)
stride_inner = logits.stride(1)
n_inner = logits.shape[1]
for start in range(0, N, MAX_GRID_STRIDED):
n_rows = min(MAX_GRID_STRIDED, N - start)
_entropy_online_kernel_strided[(n_rows,)](
logits,
output[start],
stride_outer,
stride_inner,
n_inner,
start,
V=V,
BLOCK_V=BLOCK_V,
)
else:
logits = logits.contiguous()
flat_logits = logits.reshape(-1, V)
stride = flat_logits.stride(0)
for start in range(0, N, MAX_GRID_CONTIG):
n_rows = min(MAX_GRID_CONTIG, N - start)
_entropy_online_kernel[(n_rows,)](
flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
)
return output.to(logits.dtype).reshape(original_shape)
# ---------------------------------------------------------------------------
# selective_log_softmax — fused forward + backward Triton kernels
# ---------------------------------------------------------------------------
def selective_log_softmax_original(logits, index) -> torch.Tensor:
"""Original selective_log_softmax (reference/fallback)."""
squeeze = index.ndim == logits.ndim - 1
if squeeze:
index = index.unsqueeze(-1)
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=index)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1)
else:
per_token_logps = []
for row_logits, row_labels in zip(logits, index, strict=True):
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
if squeeze:
per_token_logps = per_token_logps.squeeze(-1)
return per_token_logps
@triton.jit
def _selective_logsoftmax_fwd_kernel(
logits_ptr,
index_ptr,
output_ptr,
logsumexp_ptr,
stride_logits_row,
stride_index_row,
stride_output_row,
actual_K,
K_BLOCK: tl.constexpr,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Forward: online logsumexp + gather. Saves logsumexp for backward."""
row = tl.program_id(0)
logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
# Online logsumexp
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(logits_row_ptr + offs, mask=mask, other=float("-inf")).to(
tl.float32
)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)
exp_x = tl.exp(x - new_max)
exp_x = tl.where(mask, exp_x, 0.0)
running_sum_exp += tl.sum(exp_x, axis=0)
running_max = new_max
lse = tl.log(running_sum_exp) + running_max
tl.store(logsumexp_ptr + row, lse)
# Gather and subtract
index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
output_row_ptr = output_ptr + tl.cast(row, tl.int64) * stride_output_row
k_offs = tl.arange(0, K_BLOCK)
k_mask = k_offs < actual_K
indices = tl.load(index_row_ptr + k_offs, mask=k_mask, other=0).to(tl.int64)
valid_mask = k_mask & (indices >= 0) & (indices < V)
safe_indices = tl.where(valid_mask, indices, 0)
selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask, other=0.0).to(
tl.float32
)
tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)
@triton.jit
def _selective_logsoftmax_bwd_kernel(
grad_output_ptr,
logits_ptr,
index_ptr,
logsumexp_ptr,
grad_logits_ptr,
stride_grad_out_row,
stride_logits_row,
stride_index_row,
stride_grad_logits_row,
actual_K,
K_BLOCK: tl.constexpr,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Backward: d_logits[j] = -softmax(x)[j] * sum(grad_out) + (grad_out[k] if j == index[k]).
Single fused pass over V. For each tile, computes the base gradient and adds
scatter contributions inline by checking which indices fall in the current tile.
No separate scatter pass — no read-after-write issues.
"""
row = tl.program_id(0)
logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
grad_logits_row_ptr = (
grad_logits_ptr + tl.cast(row, tl.int64) * stride_grad_logits_row
)
grad_out_row_ptr = grad_output_ptr + tl.cast(row, tl.int64) * stride_grad_out_row
index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
lse = tl.load(logsumexp_ptr + row).to(tl.float32)
# Load grad_output and indices (K_BLOCK elements, masked)
k_offs = tl.arange(0, K_BLOCK)
k_mask = k_offs < actual_K
grad_out = tl.load(grad_out_row_ptr + k_offs, mask=k_mask, other=0.0).to(tl.float32)
indices = tl.load(
index_row_ptr + k_offs, mask=k_mask, other=-1
) # -1 = never matches
valid_mask = k_mask & (indices >= 0) & (indices < V)
grad_out = tl.where(valid_mask, grad_out, 0.0)
indices = tl.where(valid_mask, indices, -1)
grad_sum = tl.sum(grad_out, axis=0)
# Fused pass: for each tile, compute -softmax * grad_sum + scatter
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V) # [BLOCK_V]
mask = offs < V
x = tl.load(logits_row_ptr + offs, mask=mask, other=0.0).to(tl.float32)
softmax_j = tl.exp(x - lse)
softmax_j = tl.where(mask, softmax_j, 0.0)
grad_j = -softmax_j * grad_sum
# Scatter: check which selected indices fall in this tile
# offs: [BLOCK_V], indices: [K_BLOCK]
# Broadcast: offs[:, None] == indices[None, :] → [BLOCK_V, K_BLOCK]
match = offs[:, None] == indices[None, :] # [BLOCK_V, K_BLOCK]
# Sum grad_out contributions: for each position j, sum grad_out[k] where index[k]==j
scatter_contrib = tl.sum(
tl.where(match, grad_out[None, :], 0.0), axis=1
) # [BLOCK_V]
grad_j += scatter_contrib
tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)
class _SelectiveLogSoftmaxTriton(torch.autograd.Function):
@staticmethod
def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID):
N = flat_logits.shape[0]
output = torch.empty(N, K_BLOCK, device=flat_logits.device, dtype=torch.float32)
logsumexp = torch.empty(N, device=flat_logits.device, dtype=torch.float32)
for start in range(0, N, MAX_GRID):
n_rows = min(MAX_GRID, N - start)
_selective_logsoftmax_fwd_kernel[(n_rows,)](
flat_logits[start],
flat_index[start],
output[start],
logsumexp[start],
flat_logits.stride(0),
flat_index.stride(0),
output.stride(0),
K,
K_BLOCK=K_BLOCK,
V=V,
BLOCK_V=BLOCK_V,
)
ctx.save_for_backward(flat_logits, flat_index, logsumexp)
ctx.K = K
ctx.K_BLOCK = K_BLOCK
ctx.V = V
ctx.BLOCK_V = BLOCK_V
ctx.MAX_GRID = MAX_GRID
return output
@staticmethod
def backward(ctx, grad_output):
flat_logits, flat_index, logsumexp = ctx.saved_tensors
K, K_BLOCK, V, BLOCK_V, MAX_GRID = (
ctx.K,
ctx.K_BLOCK,
ctx.V,
ctx.BLOCK_V,
ctx.MAX_GRID,
)
N = flat_logits.shape[0]
grad_logits = torch.empty_like(flat_logits)
# grad_output may have K_BLOCK cols; backward kernel reads actual_K
grad_output_contig = grad_output.contiguous()
for start in range(0, N, MAX_GRID):
n_rows = min(MAX_GRID, N - start)
_selective_logsoftmax_bwd_kernel[(n_rows,)](
grad_output_contig[start],
flat_logits[start],
flat_index[start],
logsumexp[start],
grad_logits[start],
grad_output_contig.stride(0),
flat_logits.stride(0),
flat_index.stride(0),
grad_logits.stride(0),
K,
K_BLOCK=K_BLOCK,
V=V,
BLOCK_V=BLOCK_V,
)
# Return grads for: flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
return grad_logits, None, None, None, None, None, None
def selective_log_softmax(logits, index) -> torch.Tensor:
"""
Fused selective_log_softmax with Triton forward+backward kernels.
Equivalent to: torch.gather(logits.log_softmax(-1), dim=-1, index=index)
"""
squeeze = index.ndim == logits.ndim - 1
if squeeze:
index = index.unsqueeze(-1)
if not logits.is_cuda or logits.dtype == torch.float64:
# Triton kernel computes in float32; fall back for float64 and CPU
return selective_log_softmax_original(
logits, index.squeeze(-1) if squeeze else index
)
V = logits.shape[-1]
K = index.shape[-1]
original_index_shape = index.shape
flat_logits = logits.reshape(-1, V).contiguous()
flat_index = index.reshape(-1, K).contiguous()
BLOCK_V = 4096
MAX_GRID = 8192
K_BLOCK = max(1, triton.next_power_of_2(K))
output = _SelectiveLogSoftmaxTriton.apply(
flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
)
if K_BLOCK != K:
output = output[:, :K]
per_token_logps = output.to(logits.dtype).reshape(original_index_shape)
if squeeze:
per_token_logps = per_token_logps.squeeze(-1)
return per_token_logps

View File

@@ -0,0 +1,481 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
"""Unit tests for Triton kernels: entropy_from_logits and selective_log_softmax.
Adapted from harness/test_entropy.py and harness/test_selective_logsoftmax.py
into proper pytest tests, plus new OOB index safety tests.
"""
import math
import pytest
import torch
import torch.nn.functional as F
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
)
# ---------------------------------------------------------------------------
# Reference implementations
# ---------------------------------------------------------------------------
def _ref_entropy(logits):
"""Reference entropy via log_softmax (numerically stable)."""
logp = F.log_softmax(logits.float(), dim=-1)
return -(logp.exp() * logp).sum(dim=-1)
def _ref_selective_log_softmax(logits, index):
"""Reference selective log softmax via PyTorch gather."""
squeeze = index.ndim == logits.ndim - 1
if squeeze:
index = index.unsqueeze(-1)
log_probs = F.log_softmax(logits.float(), dim=-1)
result = torch.gather(log_probs, dim=-1, index=index)
if squeeze:
result = result.squeeze(-1)
return result
# ---------------------------------------------------------------------------
# entropy_from_logits
# ---------------------------------------------------------------------------
class TestEntropyFromLogits:
@pytest.mark.parametrize(
"B,L",
[
(1, 128),
(1, 2048),
(4, 512),
(8, 256),
(1, 1),
],
)
def test_correctness_various_shapes(self, B, L):
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 1024
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32)
result = entropy_from_logits(logits)
expected = _ref_entropy(logits)
assert result.shape == (B, L)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_2d_input(self):
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
logits = torch.randn(16, 256, device="cuda", dtype=torch.float32)
result = entropy_from_logits(logits)
expected = _ref_entropy(logits)
assert result.shape == (16,)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_large_vocab(self):
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 32000
logits = torch.randn(2, V, device="cuda", dtype=torch.float32)
result = entropy_from_logits(logits)
expected = _ref_entropy(logits)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_uniform_distribution(self):
"""Uniform logits -> entropy = log(V)."""
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 1024
logits = torch.zeros(2, V, device="cuda", dtype=torch.float32)
result = entropy_from_logits(logits)
expected_val = math.log(V)
torch.testing.assert_close(
result,
torch.full((2,), expected_val, device="cuda", dtype=torch.float32),
atol=1e-4,
rtol=1e-4,
)
def test_peaked_distribution(self):
"""One-hot-like logits -> entropy near 0."""
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
logits = torch.full((2, 128), -100.0, device="cuda", dtype=torch.float32)
logits[:, 0] = 100.0
result = entropy_from_logits(logits)
assert (result < 1e-3).all()
def test_bfloat16(self):
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
logits = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
result = entropy_from_logits(logits)
expected = _ref_entropy(logits.float())
assert result.dtype == torch.bfloat16
torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)
def test_float16(self):
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
logits = torch.randn(4, 256, device="cuda", dtype=torch.float16)
result = entropy_from_logits(logits)
expected = _ref_entropy(logits.float())
assert result.dtype == torch.float16
torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)
def test_non_contiguous_3d_transpose(self):
"""Non-contiguous 3D tensor via transpose(0,1)."""
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 256
raw = torch.randn(32, 4, V, device="cuda", dtype=torch.float32)
logits = raw.transpose(0, 1) # (4, 32, V) non-contiguous
assert not logits.is_contiguous()
result = entropy_from_logits(logits)
expected = _ref_entropy(logits)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_non_contiguous_3d_slice(self):
"""Non-contiguous 3D tensor via batch slicing."""
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 256
raw = torch.randn(8, 32, V, device="cuda", dtype=torch.float32)
logits = raw[::2] # (4, 32, V) non-contiguous
assert not logits.is_contiguous()
result = entropy_from_logits(logits)
expected = _ref_entropy(logits)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_many_rows_beyond_max_grid(self):
"""More rows than MAX_GRID (8192) to test chunked dispatch."""
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
logits = torch.randn(10000, 128, device="cuda", dtype=torch.float32)
result = entropy_from_logits(logits)
expected = _ref_entropy(logits)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_entropy_non_negative(self):
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
logits = torch.randn(32, 512, device="cuda", dtype=torch.float32)
result = entropy_from_logits(logits)
assert (result >= -1e-5).all(), f"Negative entropy: {result.min()}"
# ---------------------------------------------------------------------------
# selective_log_softmax — forward correctness
# ---------------------------------------------------------------------------
class TestSelectiveLogSoftmax:
@pytest.mark.parametrize(
"B,L,K",
[
(1, 128, 1),
(4, 512, 1),
(8, 256, 1),
(4, 256, 4),
(4, 256, 7),
(15, 129, 1), # non-power-of-2
],
)
def test_correctness_various_shapes(self, B, L, K):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 1024
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32)
if K == 1:
index = torch.randint(0, V, (B, L), device="cuda")
else:
index = torch.randint(0, V, (B, L, K), device="cuda")
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_squeezed_index(self):
"""Index with ndim == logits.ndim - 1 triggers squeeze path."""
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 256
logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
index = torch.randint(0, V, (8,), device="cuda")
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
assert result.shape == (8,)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_large_vocab(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 32000
logits = torch.randn(2, V, device="cuda", dtype=torch.float32)
index = torch.randint(0, V, (2, 1), device="cuda")
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_bfloat16(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 1024
torch.manual_seed(42)
logits = torch.randn(4, 128, V, device="cuda", dtype=torch.bfloat16)
index = torch.randint(0, V, (4, 128), device="cuda")
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits.float(), index)
assert result.dtype == torch.bfloat16
torch.testing.assert_close(result.float(), expected, atol=0.1, rtol=0.1)
def test_fp32_tight_tolerance(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 1024
torch.manual_seed(42)
logits = torch.randn(2, 256, V, device="cuda", dtype=torch.float32)
index = torch.randint(0, V, (2, 256), device="cuda")
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
def test_all_same_index(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
index = torch.zeros(8, 1, device="cuda", dtype=torch.long)
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_last_index(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(8, V, device="cuda", dtype=torch.float32)
index = torch.full((8, 1), V - 1, device="cuda", dtype=torch.long)
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
def test_output_always_nonpositive(self):
"""Log softmax values should always be <= 0."""
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 256
logits = torch.randn(32, V, device="cuda", dtype=torch.float32)
index = torch.randint(0, V, (32, 1), device="cuda")
result = selective_log_softmax(logits, index)
assert (result <= 1e-5).all(), f"Positive log-prob: {result.max()}"
def test_many_rows_beyond_max_grid(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(10000, V, device="cuda", dtype=torch.float32)
index = torch.randint(0, V, (10000, 1), device="cuda")
result = selective_log_softmax(logits, index)
expected = _ref_selective_log_softmax(logits, index)
torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)
# ---------------------------------------------------------------------------
# selective_log_softmax — backward / gradient correctness
# ---------------------------------------------------------------------------
class TestSelectiveLogSoftmaxBackward:
@pytest.mark.parametrize(
"B,L,V,K",
[
(2, 16, 64, 1),
(2, 16, 64, 4),
(1, 8, 128, 1),
(2, 8, 128, 7),
],
)
def test_gradient_matches_reference(self, B, L, V, K):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
torch.manual_seed(42)
logits_ref = torch.randn(
B, L, V, device="cuda", dtype=torch.float32, requires_grad=True
)
logits_tri = logits_ref.detach().clone().requires_grad_(True)
if K == 1:
index = torch.randint(0, V, (B, L), device="cuda")
else:
index = torch.randint(0, V, (B, L, K), device="cuda")
ref_out = _ref_selective_log_softmax(logits_ref, index)
tri_out = selective_log_softmax(logits_tri, index)
ref_out.sum().backward()
tri_out.sum().backward()
torch.testing.assert_close(
logits_tri.grad, logits_ref.grad, atol=1e-5, rtol=1e-5
)
def test_gradient_bfloat16_full_vocab(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 4096
torch.manual_seed(42)
logits_ref = torch.randn(
2, 64, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
logits_tri = logits_ref.detach().clone().requires_grad_(True)
index = torch.randint(0, V, (2, 64), device="cuda")
_ref_selective_log_softmax(logits_ref, index).sum().backward()
selective_log_softmax(logits_tri, index).sum().backward()
torch.testing.assert_close(
logits_tri.grad.float(), logits_ref.grad.float(), atol=0.1, rtol=0.1
)
def test_gradient_k1_squeezed(self):
"""Gradient with squeezed (1D) index."""
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 256
logits = torch.randn(
8, V, device="cuda", dtype=torch.float32, requires_grad=True
)
index = torch.randint(0, V, (8,), device="cuda")
result = selective_log_softmax(logits, index)
result.sum().backward()
triton_grad = logits.grad.clone()
logits.grad = None
ref = torch.gather(
F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1)
).squeeze(-1)
ref.sum().backward()
torch.testing.assert_close(triton_grad, logits.grad, atol=1e-4, rtol=1e-4)
# ---------------------------------------------------------------------------
# selective_log_softmax — out-of-bounds index safety
# ---------------------------------------------------------------------------
class TestSelectiveLogSoftmaxOOBSafety:
"""Verify that out-of-range indices don't crash or corrupt valid results."""
def test_negative_indices_no_crash(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
index = torch.tensor(
[[-1], [0], [V - 1], [-5]], device="cuda", dtype=torch.long
)
result = selective_log_softmax(logits, index)
assert result.shape == (4, 1)
# Valid rows should be finite and match reference
valid_idx = torch.tensor([[0], [V - 1]], device="cuda", dtype=torch.long)
valid_logits = logits[1:3]
expected = _ref_selective_log_softmax(valid_logits, valid_idx)
torch.testing.assert_close(result[1:3], expected, atol=1e-4, rtol=1e-4)
def test_index_exceeds_vocab_no_crash(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
index = torch.tensor(
[[0], [V], [V + 100], [V - 1]], device="cuda", dtype=torch.long
)
result = selective_log_softmax(logits, index)
assert result.shape == (4, 1)
# Valid rows (0 and 3) should match reference
for row_idx, idx_val in [(0, 0), (3, V - 1)]:
ref = _ref_selective_log_softmax(
logits[row_idx : row_idx + 1],
torch.tensor([[idx_val]], device="cuda", dtype=torch.long),
)
torch.testing.assert_close(
result[row_idx : row_idx + 1], ref, atol=1e-4, rtol=1e-4
)
def test_mixed_valid_invalid_multi_index(self):
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 256
K = 3
logits = torch.randn(4, V, device="cuda", dtype=torch.float32)
index = torch.tensor(
[
[0, 10, -1], # last invalid
[V, 5, 100], # first invalid
[50, 60, 70], # all valid
[-1, V + 1, -100], # all invalid
],
device="cuda",
dtype=torch.long,
)
result = selective_log_softmax(logits, index)
assert result.shape == (4, K)
# Row 2 (all valid) must match reference exactly
valid_index = torch.tensor([[50, 60, 70]], device="cuda", dtype=torch.long)
expected = _ref_selective_log_softmax(logits[2:3], valid_index)
torch.testing.assert_close(result[2:3], expected, atol=1e-4, rtol=1e-4)
def test_oob_backward_no_crash(self):
"""Backward with OOB indices should not crash and grads should be finite."""
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(
4, V, device="cuda", dtype=torch.float32, requires_grad=True
)
index = torch.tensor(
[[-1], [0], [V + 10], [V - 1]], device="cuda", dtype=torch.long
)
result = selective_log_softmax(logits, index)
result.sum().backward()
assert logits.grad is not None
assert torch.isfinite(logits.grad).all()
def test_oob_backward_valid_rows_correct(self):
"""Gradients for valid-index rows should match reference even when other rows have OOB."""
from axolotl.monkeypatch.trainer.utils import selective_log_softmax
V = 128
logits = torch.randn(
4, V, device="cuda", dtype=torch.float32, requires_grad=True
)
# Row 0: invalid, Row 1: valid, Row 2: invalid, Row 3: valid
index = torch.tensor(
[[-1], [42], [V + 5], [100]], device="cuda", dtype=torch.long
)
result = selective_log_softmax(logits, index)
result.sum().backward()
# Compute reference gradient for valid rows only
logits_ref = logits.detach().clone().requires_grad_(True)
valid_rows = [1, 3]
valid_indices = [42, 100]
for r, idx in zip(valid_rows, valid_indices, strict=True):
ref_lp = F.log_softmax(logits_ref[r : r + 1], dim=-1)
ref_val = ref_lp[0, idx]
ref_val.backward(retain_graph=True)
for r in valid_rows:
torch.testing.assert_close(
logits.grad[r], logits_ref.grad[r], atol=1e-4, rtol=1e-4
)