From 163bd4dd5a9dc6097d923cdd733cf5d1593c056c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 19 Mar 2026 02:02:43 -0400 Subject: [PATCH] use custom triton kernels for entropy from logits and selective softmax (#3510) * use custom triton kernels for entropy from logits and selective softmax * PR comments fixes * fix out of bounds, include tests, include benchmarks * chore: lint --- benchmarks/bench_entropy.py | 208 +++++++++ benchmarks/bench_selective_logsoftmax.py | 191 ++++++++ src/axolotl/loaders/patch_manager.py | 34 ++ src/axolotl/monkeypatch/trainer/__init__.py | 3 + src/axolotl/monkeypatch/trainer/utils.py | 429 +++++++++++++++++ tests/test_triton_kernels.py | 481 ++++++++++++++++++++ 6 files changed, 1346 insertions(+) create mode 100644 benchmarks/bench_entropy.py create mode 100644 benchmarks/bench_selective_logsoftmax.py create mode 100644 src/axolotl/monkeypatch/trainer/utils.py create mode 100644 tests/test_triton_kernels.py diff --git a/benchmarks/bench_entropy.py b/benchmarks/bench_entropy.py new file mode 100644 index 000000000..95c7291b3 --- /dev/null +++ b/benchmarks/bench_entropy.py @@ -0,0 +1,208 @@ +"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation. + +Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py +""" + +import gc +import statistics + +import torch +import torch.nn.functional as F + +from axolotl.monkeypatch.trainer.utils import entropy_from_logits + +V = 151936 # Qwen vocab +WARMUP = 5 +BENCH_ITERS = 20 +MEM_ITERS = 10 + + +def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128): + """Original chunked implementation (reference).""" + original_shape = logits.shape[:-1] + num_classes = logits.shape[-1] + flat_logits = logits.reshape(-1, num_classes) + entropies = [] + for chunk in flat_logits.split(chunk_size, dim=0): + logps = F.log_softmax(chunk, dim=-1) + chunk_entropy = -(torch.exp(logps) * logps).sum(-1) + entropies.append(chunk_entropy) + return torch.cat(entropies, dim=0).reshape(original_shape) + + +def _clean_gpu(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + torch.cuda.synchronize() + + +def profile_time(fn, logits, n_iters=BENCH_ITERS): + for _ in range(WARMUP): + out = fn(logits, chunk_size=128) + del out + torch.cuda.synchronize() + + times = [] + for _ in range(n_iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + out = fn(logits, chunk_size=128) + e.record() + torch.cuda.synchronize() + times.append(s.elapsed_time(e)) + del out + return times + + +def profile_memory(fn, logits, n_iters=MEM_ITERS): + for _ in range(WARMUP): + out = fn(logits, chunk_size=128) + del out + torch.cuda.synchronize() + + peaks = [] + for _ in range(n_iters): + _clean_gpu() + base = torch.cuda.max_memory_allocated() + out = fn(logits, chunk_size=128) + torch.cuda.synchronize() + peaks.append(torch.cuda.max_memory_allocated() - base) + del out + return [p / 1e6 for p in peaks] + + +def fmt(values, unit=""): + mean = statistics.mean(values) + std = statistics.stdev(values) if len(values) > 1 else 0.0 + return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]" + + +def benchmark_contiguous(): + print("=" * 60) + print( + f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})" + ) + print("=" * 60) + + configs = [ + (1, 2048), + (1, 8192), + (1, 16384), + (4, 4096), + (8, 2048), + (16, 2048), + (16, 4096), + ] 
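+
+    # Illustrative memory arithmetic (assumes bf16 logits and the Qwen
+    # vocab V = 151936): the largest config above, B=16, L=4096, allocates
+    # 16 * 4096 * 151936 * 2 bytes ≈ 19.9 GB, which stays under the 28 GB
+    # skip threshold enforced in the loop below.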
+ + for B, L in configs: + mem_gb = B * L * V * 2 / 1e9 + if mem_gb > 28: + print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)") + continue + + N = B * L + print(f"\n{'─' * 60}") + print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)") + print(f"{'─' * 60}") + + torch.manual_seed(42) + logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16) + + t_orig = profile_time(entropy_from_logits_original, logits) + t_triton = profile_time(entropy_from_logits, logits) + orig_mean = statistics.mean(t_orig) + triton_mean = statistics.mean(t_triton) + + print(" TIME (ms):") + print(f" original: {fmt(t_orig, 'ms')}") + print(f" triton: {fmt(t_triton, 'ms')}") + print(f" speedup: {orig_mean / triton_mean:.2f}x") + + m_orig = profile_memory(entropy_from_logits_original, logits) + m_triton = profile_memory(entropy_from_logits, logits) + orig_peak = statistics.mean(m_orig) + triton_peak = statistics.mean(m_triton) + + print(" MEMORY (peak overhead):") + print(f" original: {fmt(m_orig, 'MB')}") + print(f" triton: {fmt(m_triton, 'MB')}") + print(f" saved: {orig_peak - triton_peak:.1f} MB") + + del logits + _clean_gpu() + + +def benchmark_noncontiguous(): + print("\n" + "=" * 60) + print( + f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})" + ) + print("=" * 60) + + configs = [ + (4, 2048, "transpose"), + (4, 8192, "transpose"), + (8, 2048, "transpose"), + (4, 4096, "slice_batch"), + ] + + for B, L, method in configs: + torch.manual_seed(42) + + if method == "transpose": + raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16) + logits_nc = raw.transpose(0, 1) + raw_gb = L * B * V * 2 / 1e9 + elif method == "slice_batch": + raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16) + logits_nc = raw[::2] + raw_gb = B * 2 * L * V * 2 / 1e9 + else: + continue + + if raw_gb > 28: + print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)") + del raw, logits_nc + torch.cuda.empty_cache() + continue + + N = B * L + print(f"\n{'─' * 60}") + print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)") + print(f"{'─' * 60}") + + def original_with_copy(logits, chunk_size=128): + return entropy_from_logits_original( + logits.contiguous(), chunk_size=chunk_size + ) + + t_orig = profile_time(original_with_copy, logits_nc) + t_triton = profile_time(entropy_from_logits, logits_nc) + orig_mean = statistics.mean(t_orig) + triton_mean = statistics.mean(t_triton) + + print(" TIME (ms):") + print(f" orig+copy: {fmt(t_orig, 'ms')}") + print(f" triton-strided:{fmt(t_triton, 'ms')}") + print(f" speedup: {orig_mean / triton_mean:.2f}x") + + m_orig = profile_memory(original_with_copy, logits_nc) + m_triton = profile_memory(entropy_from_logits, logits_nc) + orig_peak = statistics.mean(m_orig) + triton_peak = statistics.mean(m_triton) + + print(" MEMORY (peak overhead):") + print(f" orig+copy: {fmt(m_orig, 'MB')}") + print(f" triton-strided:{fmt(m_triton, 'MB')}") + print(f" saved: {orig_peak - triton_peak:.1f} MB") + + del raw, logits_nc + _clean_gpu() + + +if __name__ == "__main__": + benchmark_contiguous() + benchmark_noncontiguous() diff --git a/benchmarks/bench_selective_logsoftmax.py b/benchmarks/bench_selective_logsoftmax.py new file mode 100644 index 000000000..b8f517e4b --- /dev/null +++ b/benchmarks/bench_selective_logsoftmax.py @@ -0,0 +1,191 @@ +"""Benchmark for selective_log_softmax Triton kernel vs original implementation. 
+ +Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py +""" + +import gc +import statistics + +import torch + +from axolotl.monkeypatch.trainer.utils import ( + selective_log_softmax, + selective_log_softmax_original, +) + +V = 151936 # Qwen vocab +WARMUP = 5 +BENCH_ITERS = 20 +MEM_ITERS = 10 + + +def _clean_gpu(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + torch.cuda.synchronize() + + +def profile_time(fn, args, n_iters=BENCH_ITERS): + for _ in range(WARMUP): + fn(*args) + torch.cuda.synchronize() + + times = [] + for _ in range(n_iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + fn(*args) + e.record() + torch.cuda.synchronize() + times.append(s.elapsed_time(e)) + return times + + +def profile_memory(fn, args, n_iters=MEM_ITERS): + for _ in range(WARMUP): + out = fn(*args) + del out + torch.cuda.synchronize() + + peaks = [] + for _ in range(n_iters): + _clean_gpu() + base = torch.cuda.max_memory_allocated() + out = fn(*args) + torch.cuda.synchronize() + peaks.append(torch.cuda.max_memory_allocated() - base) + del out + return [p / 1e6 for p in peaks] + + +def fmt(values, unit=""): + mean = statistics.mean(values) + std = statistics.stdev(values) if len(values) > 1 else 0.0 + return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]" + + +def benchmark_forward(): + print("=" * 60) + print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})") + print("=" * 60) + + configs = [ + (1, 2048), + (1, 8192), + (4, 4096), + (8, 2048), + (16, 2048), + (16, 4096), + ] + + for B, L in configs: + mem_gb = B * L * V * 2 / 1e9 + if mem_gb > 28: + print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)") + continue + + N = B * L + print(f"\n{'─' * 60}") + print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)") + print(f"{'─' * 60}") + + torch.manual_seed(42) + logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16) + index = torch.randint(0, V, (B, L), device="cuda") + + t_orig = profile_time(selective_log_softmax_original, (logits, index)) + t_triton = profile_time(selective_log_softmax, (logits, index)) + orig_mean = statistics.mean(t_orig) + triton_mean = statistics.mean(t_triton) + + print(" TIME (ms):") + print(f" original: {fmt(t_orig, 'ms')}") + print(f" triton: {fmt(t_triton, 'ms')}") + print(f" speedup: {orig_mean / triton_mean:.2f}x") + + m_orig = profile_memory(selective_log_softmax_original, (logits, index)) + m_triton = profile_memory(selective_log_softmax, (logits, index)) + orig_peak = statistics.mean(m_orig) + triton_peak = statistics.mean(m_triton) + + print(" MEMORY (peak overhead):") + print(f" original: {fmt(m_orig, 'MB')}") + print(f" triton: {fmt(m_triton, 'MB')}") + print(f" saved: {orig_peak - triton_peak:.1f} MB") + + del logits, index + _clean_gpu() + + +def benchmark_backward(): + print("\n" + "=" * 60) + print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})") + print("=" * 60) + + configs = [ + (1, 2048), + (1, 8192), + (4, 4096), + (8, 2048), + (16, 2048), + (16, 4096), + ] + + def fwd_bwd_original(logits, index): + logits.grad = None + out = selective_log_softmax_original(logits, index) + out.sum().backward() + + def fwd_bwd_triton(logits, index): + logits.grad = None + out = selective_log_softmax(logits, index) + out.sum().backward() + + for B, L in configs: + mem_gb = B * L * V * 2 / 1e9 + if mem_gb > 20: + 
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)") + continue + + N = B * L + print(f"\n{'─' * 60}") + print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)") + print(f"{'─' * 60}") + + torch.manual_seed(42) + logits_orig = torch.randn( + B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + logits_tri = logits_orig.detach().clone().requires_grad_(True) + index = torch.randint(0, V, (B, L), device="cuda") + + t_orig = profile_time(fwd_bwd_original, (logits_orig, index)) + t_triton = profile_time(fwd_bwd_triton, (logits_tri, index)) + orig_mean = statistics.mean(t_orig) + triton_mean = statistics.mean(t_triton) + + print(" FWD+BWD TIME (ms):") + print(f" original: {fmt(t_orig, 'ms')}") + print(f" triton: {fmt(t_triton, 'ms')}") + print(f" speedup: {orig_mean / triton_mean:.2f}x") + + m_orig = profile_memory(fwd_bwd_original, (logits_orig, index)) + m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index)) + orig_peak = statistics.mean(m_orig) + triton_peak = statistics.mean(m_triton) + + print(" FWD+BWD MEMORY (peak overhead):") + print(f" original: {fmt(m_orig, 'MB')}") + print(f" triton: {fmt(m_triton, 'MB')}") + print(f" saved: {orig_peak - triton_peak:.1f} MB") + + del logits_orig, logits_tri, index + _clean_gpu() + + +if __name__ == "__main__": + benchmark_forward() + benchmark_backward() diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 205e32e6f..bddd388e4 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -117,6 +117,7 @@ class PatchManager: self._apply_voxtral_patches() self._apply_apertus_patches() self._apply_trl_vllm_patches() + self._apply_trl_trainer_utils_patches() def apply_post_plugin_pre_model_load_patches(self): """Apply post plugin-pre_model_load load patches based on config.""" @@ -679,6 +680,39 @@ class PatchManager: patch_trl_vllm() + def _apply_trl_trainer_utils_patches(self): + """Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels.""" + if not self.cfg.rl: + return + + try: + from axolotl.monkeypatch.trainer.utils import ( + entropy_from_logits, + selective_log_softmax, + ) + except (ImportError, ModuleNotFoundError): + LOG.warning("Triton not available — skipping trl.trainer.utils patches") + return + + import trl.trainer.utils + + # Guard against repeated calls: only stash the original if trl still + # points at its own implementation (not our wrapper). + if trl.trainer.utils.selective_log_softmax is not selective_log_softmax: + from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils + + _axolotl_trainer_utils.selective_log_softmax_original = ( + trl.trainer.utils.selective_log_softmax + ) + trl.trainer.utils.selective_log_softmax = selective_log_softmax + + if trl.trainer.utils.entropy_from_logits is not entropy_from_logits: + trl.trainer.utils.entropy_from_logits = entropy_from_logits + + LOG.info( + "Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits" + ) + def _apply_scaling_softmax_patch(self, model: PreTrainedModel): """Apply Scaling Softmax (SSMax) patch. 
Ref: https://arxiv.org/abs/2501.19399""" if self.cfg.scaling_softmax: diff --git a/src/axolotl/monkeypatch/trainer/__init__.py b/src/axolotl/monkeypatch/trainer/__init__.py index e69de29bb..27edc63c3 100644 --- a/src/axolotl/monkeypatch/trainer/__init__.py +++ b/src/axolotl/monkeypatch/trainer/__init__.py @@ -0,0 +1,3 @@ +from .utils import entropy_from_logits, selective_log_softmax + +__all__ = ["entropy_from_logits", "selective_log_softmax"] diff --git a/src/axolotl/monkeypatch/trainer/utils.py b/src/axolotl/monkeypatch/trainer/utils.py new file mode 100644 index 000000000..467f50a5a --- /dev/null +++ b/src/axolotl/monkeypatch/trainer/utils.py @@ -0,0 +1,429 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + + +@triton.jit +def _entropy_online_kernel( + logits_ptr, + output_ptr, + stride_row, + V: tl.constexpr, + BLOCK_V: tl.constexpr, +): + """Online entropy: single pass with running max correction.""" + row = tl.program_id(0) + row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row + + running_max = tl.full([], float("-inf"), dtype=tl.float32) + running_sum_exp = tl.full([], 0.0, dtype=tl.float32) + running_weighted = tl.full([], 0.0, dtype=tl.float32) + + for v_start in range(0, V, BLOCK_V): + offs = v_start + tl.arange(0, BLOCK_V) + mask = offs < V + x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32) + + block_max = tl.max(x, axis=0) + new_max = tl.maximum(running_max, block_max) + + correction = tl.exp(running_max - new_max) + running_sum_exp = running_sum_exp * correction + running_weighted = running_weighted * correction + + exp_x = tl.exp(x - new_max) + exp_x = tl.where(mask, exp_x, 0.0) + x = tl.where(mask, x, 0.0) + running_sum_exp += tl.sum(exp_x, axis=0) + running_weighted += tl.sum(exp_x * x, axis=0) + + running_max = new_max + + entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp + tl.store(output_ptr + row, entropy) + + +@triton.jit +def _entropy_online_kernel_strided( + logits_ptr, + output_ptr, + stride_outer, + stride_inner, + n_inner, + row_offset, + V: tl.constexpr, + BLOCK_V: tl.constexpr, +): + """Online entropy for non-contiguous 3D (B, L, V) tensors.""" + local_row = tl.program_id(0) + row = local_row + row_offset + outer_idx = row // n_inner + inner_idx = row % n_inner + off = outer_idx.to(tl.int64) * stride_outer + inner_idx.to(tl.int64) * stride_inner + row_ptr = logits_ptr + off + + running_max = tl.full([], float("-inf"), dtype=tl.float32) + running_sum_exp = tl.full([], 0.0, dtype=tl.float32) + running_weighted = tl.full([], 0.0, dtype=tl.float32) + + for v_start in range(0, V, BLOCK_V): + offs = v_start + tl.arange(0, BLOCK_V) + mask = offs < V + x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32) + + block_max = tl.max(x, axis=0) + new_max = tl.maximum(running_max, block_max) + + correction = tl.exp(running_max - new_max) + 
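+        # Online-softmax rescale: the exp(x - running_max) factor inside
+        # both accumulators is converted to exp(x - new_max) by multiplying
+        # with exp(running_max - new_max) before this tile's terms are added.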
running_sum_exp = running_sum_exp * correction + running_weighted = running_weighted * correction + + exp_x = tl.exp(x - new_max) + exp_x = tl.where(mask, exp_x, 0.0) + x = tl.where(mask, x, 0.0) + running_sum_exp += tl.sum(exp_x, axis=0) + running_weighted += tl.sum(exp_x * x, axis=0) + + running_max = new_max + + entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp + tl.store(output_ptr + local_row, entropy) + + +def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor: + """Triton-fused entropy (online single-pass). Handles non-contiguous tensors without copying.""" + original_shape = logits.shape[:-1] + V = logits.shape[-1] + N = 1 + for s in original_shape: + N *= s + + if not logits.is_cuda: + # CPU fallback: stable entropy via log_softmax + logp = F.log_softmax(logits.float(), dim=-1) + ent = -(logp.exp() * logp).sum(dim=-1) + return ent.to(logits.dtype).reshape(original_shape) + + output = torch.empty(N, device=logits.device, dtype=torch.float32) + + BLOCK_V = 4096 + MAX_GRID_CONTIG = 8192 + MAX_GRID_STRIDED = 2048 + + # Vocab (last) dim must be contiguous for coalesced loads + if logits.stride(-1) != 1: + logits = logits.contiguous() + + if logits.is_contiguous(): + flat_logits = logits.reshape(-1, V) + stride = flat_logits.stride(0) + for start in range(0, N, MAX_GRID_CONTIG): + n_rows = min(MAX_GRID_CONTIG, N - start) + _entropy_online_kernel[(n_rows,)]( + flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V + ) + elif logits.ndim == 3: + stride_outer = logits.stride(0) + stride_inner = logits.stride(1) + n_inner = logits.shape[1] + for start in range(0, N, MAX_GRID_STRIDED): + n_rows = min(MAX_GRID_STRIDED, N - start) + _entropy_online_kernel_strided[(n_rows,)]( + logits, + output[start], + stride_outer, + stride_inner, + n_inner, + start, + V=V, + BLOCK_V=BLOCK_V, + ) + else: + logits = logits.contiguous() + flat_logits = logits.reshape(-1, V) + stride = flat_logits.stride(0) + for start in range(0, N, MAX_GRID_CONTIG): + n_rows = min(MAX_GRID_CONTIG, N - start) + _entropy_online_kernel[(n_rows,)]( + flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V + ) + + return output.to(logits.dtype).reshape(original_shape) + + +# --------------------------------------------------------------------------- +# selective_log_softmax — fused forward + backward Triton kernels +# --------------------------------------------------------------------------- + + +def selective_log_softmax_original(logits, index) -> torch.Tensor: + """Original selective_log_softmax (reference/fallback).""" + squeeze = index.ndim == logits.ndim - 1 + if squeeze: + index = index.unsqueeze(-1) + + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=index) + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1) + else: + per_token_logps = [] + for row_logits, row_labels in zip(logits, index, strict=True): + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + + if squeeze: + per_token_logps = per_token_logps.squeeze(-1) + + return per_token_logps + + +@triton.jit +def _selective_logsoftmax_fwd_kernel( + logits_ptr, + index_ptr, + output_ptr, + logsumexp_ptr, + stride_logits_row, + stride_index_row, + stride_output_row, + actual_K, + 
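+    # actual_K: true per-row index count; K_BLOCK is its next power of two
+    # (tl.arange requires a power-of-2 extent).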
K_BLOCK: tl.constexpr, + V: tl.constexpr, + BLOCK_V: tl.constexpr, +): + """Forward: online logsumexp + gather. Saves logsumexp for backward.""" + row = tl.program_id(0) + logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row + + # Online logsumexp + running_max = tl.full([], float("-inf"), dtype=tl.float32) + running_sum_exp = tl.full([], 0.0, dtype=tl.float32) + + for v_start in range(0, V, BLOCK_V): + offs = v_start + tl.arange(0, BLOCK_V) + mask = offs < V + x = tl.load(logits_row_ptr + offs, mask=mask, other=float("-inf")).to( + tl.float32 + ) + + block_max = tl.max(x, axis=0) + new_max = tl.maximum(running_max, block_max) + running_sum_exp = running_sum_exp * tl.exp(running_max - new_max) + + exp_x = tl.exp(x - new_max) + exp_x = tl.where(mask, exp_x, 0.0) + running_sum_exp += tl.sum(exp_x, axis=0) + running_max = new_max + + lse = tl.log(running_sum_exp) + running_max + tl.store(logsumexp_ptr + row, lse) + + # Gather and subtract + index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row + output_row_ptr = output_ptr + tl.cast(row, tl.int64) * stride_output_row + + k_offs = tl.arange(0, K_BLOCK) + k_mask = k_offs < actual_K + indices = tl.load(index_row_ptr + k_offs, mask=k_mask, other=0).to(tl.int64) + valid_mask = k_mask & (indices >= 0) & (indices < V) + safe_indices = tl.where(valid_mask, indices, 0) + selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask, other=0.0).to( + tl.float32 + ) + tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask) + + +@triton.jit +def _selective_logsoftmax_bwd_kernel( + grad_output_ptr, + logits_ptr, + index_ptr, + logsumexp_ptr, + grad_logits_ptr, + stride_grad_out_row, + stride_logits_row, + stride_index_row, + stride_grad_logits_row, + actual_K, + K_BLOCK: tl.constexpr, + V: tl.constexpr, + BLOCK_V: tl.constexpr, +): + """Backward: d_logits[j] = -softmax(x)[j] * sum(grad_out) + (grad_out[k] if j == index[k]). + + Single fused pass over V. For each tile, computes the base gradient and adds + scatter contributions inline by checking which indices fall in the current tile. + No separate scatter pass — no read-after-write issues. 
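+
+    Derivation sketch: out[k] = logits[index[k]] - logsumexp(logits), so
+    d out[k] / d logits[j] = 1{j == index[k]} - softmax(logits)[j];
+    summed over k and weighted by grad_output, this yields the formula above.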
+ """ + row = tl.program_id(0) + logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row + grad_logits_row_ptr = ( + grad_logits_ptr + tl.cast(row, tl.int64) * stride_grad_logits_row + ) + grad_out_row_ptr = grad_output_ptr + tl.cast(row, tl.int64) * stride_grad_out_row + index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row + + lse = tl.load(logsumexp_ptr + row).to(tl.float32) + + # Load grad_output and indices (K_BLOCK elements, masked) + k_offs = tl.arange(0, K_BLOCK) + k_mask = k_offs < actual_K + grad_out = tl.load(grad_out_row_ptr + k_offs, mask=k_mask, other=0.0).to(tl.float32) + indices = tl.load( + index_row_ptr + k_offs, mask=k_mask, other=-1 + ) # -1 = never matches + valid_mask = k_mask & (indices >= 0) & (indices < V) + grad_out = tl.where(valid_mask, grad_out, 0.0) + indices = tl.where(valid_mask, indices, -1) + grad_sum = tl.sum(grad_out, axis=0) + + # Fused pass: for each tile, compute -softmax * grad_sum + scatter + for v_start in range(0, V, BLOCK_V): + offs = v_start + tl.arange(0, BLOCK_V) # [BLOCK_V] + mask = offs < V + x = tl.load(logits_row_ptr + offs, mask=mask, other=0.0).to(tl.float32) + softmax_j = tl.exp(x - lse) + softmax_j = tl.where(mask, softmax_j, 0.0) + grad_j = -softmax_j * grad_sum + + # Scatter: check which selected indices fall in this tile + # offs: [BLOCK_V], indices: [K_BLOCK] + # Broadcast: offs[:, None] == indices[None, :] → [BLOCK_V, K_BLOCK] + match = offs[:, None] == indices[None, :] # [BLOCK_V, K_BLOCK] + # Sum grad_out contributions: for each position j, sum grad_out[k] where index[k]==j + scatter_contrib = tl.sum( + tl.where(match, grad_out[None, :], 0.0), axis=1 + ) # [BLOCK_V] + grad_j += scatter_contrib + + tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask) + + +class _SelectiveLogSoftmaxTriton(torch.autograd.Function): + @staticmethod + def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID): + N = flat_logits.shape[0] + output = torch.empty(N, K_BLOCK, device=flat_logits.device, dtype=torch.float32) + logsumexp = torch.empty(N, device=flat_logits.device, dtype=torch.float32) + + for start in range(0, N, MAX_GRID): + n_rows = min(MAX_GRID, N - start) + _selective_logsoftmax_fwd_kernel[(n_rows,)]( + flat_logits[start], + flat_index[start], + output[start], + logsumexp[start], + flat_logits.stride(0), + flat_index.stride(0), + output.stride(0), + K, + K_BLOCK=K_BLOCK, + V=V, + BLOCK_V=BLOCK_V, + ) + + ctx.save_for_backward(flat_logits, flat_index, logsumexp) + ctx.K = K + ctx.K_BLOCK = K_BLOCK + ctx.V = V + ctx.BLOCK_V = BLOCK_V + ctx.MAX_GRID = MAX_GRID + return output + + @staticmethod + def backward(ctx, grad_output): + flat_logits, flat_index, logsumexp = ctx.saved_tensors + K, K_BLOCK, V, BLOCK_V, MAX_GRID = ( + ctx.K, + ctx.K_BLOCK, + ctx.V, + ctx.BLOCK_V, + ctx.MAX_GRID, + ) + N = flat_logits.shape[0] + + grad_logits = torch.empty_like(flat_logits) + + # grad_output may have K_BLOCK cols; backward kernel reads actual_K + grad_output_contig = grad_output.contiguous() + + for start in range(0, N, MAX_GRID): + n_rows = min(MAX_GRID, N - start) + _selective_logsoftmax_bwd_kernel[(n_rows,)]( + grad_output_contig[start], + flat_logits[start], + flat_index[start], + logsumexp[start], + grad_logits[start], + grad_output_contig.stride(0), + flat_logits.stride(0), + flat_index.stride(0), + grad_logits.stride(0), + K, + K_BLOCK=K_BLOCK, + V=V, + BLOCK_V=BLOCK_V, + ) + + # Return grads for: flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID + return grad_logits, None, None, None, 
None, None, None + + +def selective_log_softmax(logits, index) -> torch.Tensor: + """ + Fused selective_log_softmax with Triton forward+backward kernels. + + Equivalent to: torch.gather(logits.log_softmax(-1), dim=-1, index=index) + """ + squeeze = index.ndim == logits.ndim - 1 + if squeeze: + index = index.unsqueeze(-1) + + if not logits.is_cuda or logits.dtype == torch.float64: + # Triton kernel computes in float32; fall back for float64 and CPU + return selective_log_softmax_original( + logits, index.squeeze(-1) if squeeze else index + ) + + V = logits.shape[-1] + K = index.shape[-1] + original_index_shape = index.shape + + flat_logits = logits.reshape(-1, V).contiguous() + flat_index = index.reshape(-1, K).contiguous() + + BLOCK_V = 4096 + MAX_GRID = 8192 + K_BLOCK = max(1, triton.next_power_of_2(K)) + + output = _SelectiveLogSoftmaxTriton.apply( + flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID + ) + + if K_BLOCK != K: + output = output[:, :K] + + per_token_logps = output.to(logits.dtype).reshape(original_index_shape) + + if squeeze: + per_token_logps = per_token_logps.squeeze(-1) + + return per_token_logps diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py new file mode 100644 index 000000000..ffc8de865 --- /dev/null +++ b/tests/test_triton_kernels.py @@ -0,0 +1,481 @@ +# Copyright 2026 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Unit tests for Triton kernels: entropy_from_logits and selective_log_softmax. + +Adapted from harness/test_entropy.py and harness/test_selective_logsoftmax.py +into proper pytest tests, plus new OOB index safety tests. +""" + +import math + +import pytest +import torch +import torch.nn.functional as F + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required for Triton kernels" +) + + +# --------------------------------------------------------------------------- +# Reference implementations +# --------------------------------------------------------------------------- + + +def _ref_entropy(logits): + """Reference entropy via log_softmax (numerically stable).""" + logp = F.log_softmax(logits.float(), dim=-1) + return -(logp.exp() * logp).sum(dim=-1) + + +def _ref_selective_log_softmax(logits, index): + """Reference selective log softmax via PyTorch gather.""" + squeeze = index.ndim == logits.ndim - 1 + if squeeze: + index = index.unsqueeze(-1) + log_probs = F.log_softmax(logits.float(), dim=-1) + result = torch.gather(log_probs, dim=-1, index=index) + if squeeze: + result = result.squeeze(-1) + return result + + +# --------------------------------------------------------------------------- +# entropy_from_logits +# --------------------------------------------------------------------------- + + +class TestEntropyFromLogits: + @pytest.mark.parametrize( + "B,L", + [ + (1, 128), + (1, 2048), + (4, 512), + (8, 256), + (1, 1), + ], + ) + def test_correctness_various_shapes(self, B, L): + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + V = 1024 + torch.manual_seed(42) + logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32) + result = entropy_from_logits(logits) + expected = _ref_entropy(logits) + assert result.shape == (B, L) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_2d_input(self): + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + logits = torch.randn(16, 256, device="cuda", 
dtype=torch.float32) + result = entropy_from_logits(logits) + expected = _ref_entropy(logits) + assert result.shape == (16,) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_large_vocab(self): + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + V = 32000 + logits = torch.randn(2, V, device="cuda", dtype=torch.float32) + result = entropy_from_logits(logits) + expected = _ref_entropy(logits) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_uniform_distribution(self): + """Uniform logits -> entropy = log(V).""" + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + V = 1024 + logits = torch.zeros(2, V, device="cuda", dtype=torch.float32) + result = entropy_from_logits(logits) + expected_val = math.log(V) + torch.testing.assert_close( + result, + torch.full((2,), expected_val, device="cuda", dtype=torch.float32), + atol=1e-4, + rtol=1e-4, + ) + + def test_peaked_distribution(self): + """One-hot-like logits -> entropy near 0.""" + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + logits = torch.full((2, 128), -100.0, device="cuda", dtype=torch.float32) + logits[:, 0] = 100.0 + result = entropy_from_logits(logits) + assert (result < 1e-3).all() + + def test_bfloat16(self): + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + logits = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16) + result = entropy_from_logits(logits) + expected = _ref_entropy(logits.float()) + assert result.dtype == torch.bfloat16 + torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2) + + def test_float16(self): + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + logits = torch.randn(4, 256, device="cuda", dtype=torch.float16) + result = entropy_from_logits(logits) + expected = _ref_entropy(logits.float()) + assert result.dtype == torch.float16 + torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2) + + def test_non_contiguous_3d_transpose(self): + """Non-contiguous 3D tensor via transpose(0,1).""" + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + V = 256 + raw = torch.randn(32, 4, V, device="cuda", dtype=torch.float32) + logits = raw.transpose(0, 1) # (4, 32, V) non-contiguous + assert not logits.is_contiguous() + result = entropy_from_logits(logits) + expected = _ref_entropy(logits) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_non_contiguous_3d_slice(self): + """Non-contiguous 3D tensor via batch slicing.""" + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + V = 256 + raw = torch.randn(8, 32, V, device="cuda", dtype=torch.float32) + logits = raw[::2] # (4, 32, V) non-contiguous + assert not logits.is_contiguous() + result = entropy_from_logits(logits) + expected = _ref_entropy(logits) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_many_rows_beyond_max_grid(self): + """More rows than MAX_GRID (8192) to test chunked dispatch.""" + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + logits = torch.randn(10000, 128, device="cuda", dtype=torch.float32) + result = entropy_from_logits(logits) + expected = _ref_entropy(logits) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_entropy_non_negative(self): + from axolotl.monkeypatch.trainer.utils import entropy_from_logits + + logits = torch.randn(32, 512, device="cuda", dtype=torch.float32) + result = 
entropy_from_logits(logits) + assert (result >= -1e-5).all(), f"Negative entropy: {result.min()}" + + +# --------------------------------------------------------------------------- +# selective_log_softmax — forward correctness +# --------------------------------------------------------------------------- + + +class TestSelectiveLogSoftmax: + @pytest.mark.parametrize( + "B,L,K", + [ + (1, 128, 1), + (4, 512, 1), + (8, 256, 1), + (4, 256, 4), + (4, 256, 7), + (15, 129, 1), # non-power-of-2 + ], + ) + def test_correctness_various_shapes(self, B, L, K): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 1024 + torch.manual_seed(42) + logits = torch.randn(B, L, V, device="cuda", dtype=torch.float32) + if K == 1: + index = torch.randint(0, V, (B, L), device="cuda") + else: + index = torch.randint(0, V, (B, L, K), device="cuda") + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_squeezed_index(self): + """Index with ndim == logits.ndim - 1 triggers squeeze path.""" + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 256 + logits = torch.randn(8, V, device="cuda", dtype=torch.float32) + index = torch.randint(0, V, (8,), device="cuda") + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + assert result.shape == (8,) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_large_vocab(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 32000 + logits = torch.randn(2, V, device="cuda", dtype=torch.float32) + index = torch.randint(0, V, (2, 1), device="cuda") + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_bfloat16(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 1024 + torch.manual_seed(42) + logits = torch.randn(4, 128, V, device="cuda", dtype=torch.bfloat16) + index = torch.randint(0, V, (4, 128), device="cuda") + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits.float(), index) + assert result.dtype == torch.bfloat16 + torch.testing.assert_close(result.float(), expected, atol=0.1, rtol=0.1) + + def test_fp32_tight_tolerance(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 1024 + torch.manual_seed(42) + logits = torch.randn(2, 256, V, device="cuda", dtype=torch.float32) + index = torch.randint(0, V, (2, 256), device="cuda") + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + def test_all_same_index(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn(8, V, device="cuda", dtype=torch.float32) + index = torch.zeros(8, 1, device="cuda", dtype=torch.long) + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_last_index(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn(8, V, device="cuda", dtype=torch.float32) + index = torch.full((8, 1), V - 1, device="cuda", dtype=torch.long) + result = 
selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_output_always_nonpositive(self): + """Log softmax values should always be <= 0.""" + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 256 + logits = torch.randn(32, V, device="cuda", dtype=torch.float32) + index = torch.randint(0, V, (32, 1), device="cuda") + result = selective_log_softmax(logits, index) + assert (result <= 1e-5).all(), f"Positive log-prob: {result.max()}" + + def test_many_rows_beyond_max_grid(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn(10000, V, device="cuda", dtype=torch.float32) + index = torch.randint(0, V, (10000, 1), device="cuda") + result = selective_log_softmax(logits, index) + expected = _ref_selective_log_softmax(logits, index) + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + +# --------------------------------------------------------------------------- +# selective_log_softmax — backward / gradient correctness +# --------------------------------------------------------------------------- + + +class TestSelectiveLogSoftmaxBackward: + @pytest.mark.parametrize( + "B,L,V,K", + [ + (2, 16, 64, 1), + (2, 16, 64, 4), + (1, 8, 128, 1), + (2, 8, 128, 7), + ], + ) + def test_gradient_matches_reference(self, B, L, V, K): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + torch.manual_seed(42) + logits_ref = torch.randn( + B, L, V, device="cuda", dtype=torch.float32, requires_grad=True + ) + logits_tri = logits_ref.detach().clone().requires_grad_(True) + + if K == 1: + index = torch.randint(0, V, (B, L), device="cuda") + else: + index = torch.randint(0, V, (B, L, K), device="cuda") + + ref_out = _ref_selective_log_softmax(logits_ref, index) + tri_out = selective_log_softmax(logits_tri, index) + + ref_out.sum().backward() + tri_out.sum().backward() + + torch.testing.assert_close( + logits_tri.grad, logits_ref.grad, atol=1e-5, rtol=1e-5 + ) + + def test_gradient_bfloat16_full_vocab(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 4096 + torch.manual_seed(42) + logits_ref = torch.randn( + 2, 64, V, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + logits_tri = logits_ref.detach().clone().requires_grad_(True) + index = torch.randint(0, V, (2, 64), device="cuda") + + _ref_selective_log_softmax(logits_ref, index).sum().backward() + selective_log_softmax(logits_tri, index).sum().backward() + + torch.testing.assert_close( + logits_tri.grad.float(), logits_ref.grad.float(), atol=0.1, rtol=0.1 + ) + + def test_gradient_k1_squeezed(self): + """Gradient with squeezed (1D) index.""" + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 256 + logits = torch.randn( + 8, V, device="cuda", dtype=torch.float32, requires_grad=True + ) + index = torch.randint(0, V, (8,), device="cuda") + + result = selective_log_softmax(logits, index) + result.sum().backward() + triton_grad = logits.grad.clone() + + logits.grad = None + ref = torch.gather( + F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1) + ).squeeze(-1) + ref.sum().backward() + + torch.testing.assert_close(triton_grad, logits.grad, atol=1e-4, rtol=1e-4) + + +# --------------------------------------------------------------------------- +# selective_log_softmax — out-of-bounds index safety +# 
--------------------------------------------------------------------------- + + +class TestSelectiveLogSoftmaxOOBSafety: + """Verify that out-of-range indices don't crash or corrupt valid results.""" + + def test_negative_indices_no_crash(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn(4, V, device="cuda", dtype=torch.float32) + index = torch.tensor( + [[-1], [0], [V - 1], [-5]], device="cuda", dtype=torch.long + ) + result = selective_log_softmax(logits, index) + assert result.shape == (4, 1) + # Valid rows should be finite and match reference + valid_idx = torch.tensor([[0], [V - 1]], device="cuda", dtype=torch.long) + valid_logits = logits[1:3] + expected = _ref_selective_log_softmax(valid_logits, valid_idx) + torch.testing.assert_close(result[1:3], expected, atol=1e-4, rtol=1e-4) + + def test_index_exceeds_vocab_no_crash(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn(4, V, device="cuda", dtype=torch.float32) + index = torch.tensor( + [[0], [V], [V + 100], [V - 1]], device="cuda", dtype=torch.long + ) + result = selective_log_softmax(logits, index) + assert result.shape == (4, 1) + # Valid rows (0 and 3) should match reference + for row_idx, idx_val in [(0, 0), (3, V - 1)]: + ref = _ref_selective_log_softmax( + logits[row_idx : row_idx + 1], + torch.tensor([[idx_val]], device="cuda", dtype=torch.long), + ) + torch.testing.assert_close( + result[row_idx : row_idx + 1], ref, atol=1e-4, rtol=1e-4 + ) + + def test_mixed_valid_invalid_multi_index(self): + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 256 + K = 3 + logits = torch.randn(4, V, device="cuda", dtype=torch.float32) + index = torch.tensor( + [ + [0, 10, -1], # last invalid + [V, 5, 100], # first invalid + [50, 60, 70], # all valid + [-1, V + 1, -100], # all invalid + ], + device="cuda", + dtype=torch.long, + ) + result = selective_log_softmax(logits, index) + assert result.shape == (4, K) + # Row 2 (all valid) must match reference exactly + valid_index = torch.tensor([[50, 60, 70]], device="cuda", dtype=torch.long) + expected = _ref_selective_log_softmax(logits[2:3], valid_index) + torch.testing.assert_close(result[2:3], expected, atol=1e-4, rtol=1e-4) + + def test_oob_backward_no_crash(self): + """Backward with OOB indices should not crash and grads should be finite.""" + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn( + 4, V, device="cuda", dtype=torch.float32, requires_grad=True + ) + index = torch.tensor( + [[-1], [0], [V + 10], [V - 1]], device="cuda", dtype=torch.long + ) + result = selective_log_softmax(logits, index) + result.sum().backward() + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + + def test_oob_backward_valid_rows_correct(self): + """Gradients for valid-index rows should match reference even when other rows have OOB.""" + from axolotl.monkeypatch.trainer.utils import selective_log_softmax + + V = 128 + logits = torch.randn( + 4, V, device="cuda", dtype=torch.float32, requires_grad=True + ) + # Row 0: invalid, Row 1: valid, Row 2: invalid, Row 3: valid + index = torch.tensor( + [[-1], [42], [V + 5], [100]], device="cuda", dtype=torch.long + ) + result = selective_log_softmax(logits, index) + result.sum().backward() + + # Compute reference gradient for valid rows only + logits_ref = logits.detach().clone().requires_grad_(True) + valid_rows = [1, 3] + valid_indices = [42, 100] 
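+        # log_softmax acts row-wise, so gradients decompose per row; a
+        # per-row reference backward therefore matches the fused kernel
+        # row by row.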
+ for r, idx in zip(valid_rows, valid_indices, strict=True): + ref_lp = F.log_softmax(logits_ref[r : r + 1], dim=-1) + ref_val = ref_lp[0, idx] + ref_val.backward(retain_graph=True) + + for r in valid_rows: + torch.testing.assert_close( + logits.grad[r], logits_ref.grad[r], atol=1e-4, rtol=1e-4 + )
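+
+
+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (added for clarity; pytest does not collect it
+# since it lacks a test_ prefix). It restates the contract the tests above
+# verify: the two Triton entry points are drop-in replacements for their
+# trl.trainer.utils counterparts, matching the pure-PyTorch references
+# defined at the top of this file. Assumes a CUDA device is available.
+# ---------------------------------------------------------------------------
+
+
+def _example_equivalence_sketch():
+    """Minimal sketch of the equivalence the Triton kernels guarantee."""
+    from axolotl.monkeypatch.trainer.utils import (
+        entropy_from_logits,
+        selective_log_softmax,
+    )
+
+    V = 512
+    logits = torch.randn(2, 8, V, device="cuda", dtype=torch.float32)
+    index = torch.randint(0, V, (2, 8), device="cuda")
+
+    # entropy_from_logits(x) == -(p * log p).sum(-1) with p = softmax(x)
+    torch.testing.assert_close(
+        entropy_from_logits(logits), _ref_entropy(logits), atol=1e-4, rtol=1e-4
+    )
+    # selective_log_softmax(x, i) == gather(log_softmax(x, -1), -1, i)
+    torch.testing.assert_close(
+        selective_log_softmax(logits, index),
+        _ref_selective_log_softmax(logits, index),
+        atol=1e-4,
+        rtol=1e-4,
+    )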