use custom triton kernels for entropy from logits and selective softmax (#3510)
* use custom triton kernels for entropy from logits and selective softmax * PR comments fixes * fix out of bounds, include tests, include benchmarks * chore: lint
This commit is contained in:
208
benchmarks/bench_entropy.py
Normal file
208
benchmarks/bench_entropy.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
    """Reference implementation: chunked entropy over the last (vocab) dim.

    Splits the flattened rows into chunks of at most ``chunk_size`` so the
    intermediate log-probabilities never cover the whole batch at once.
    Returns a tensor of shape ``logits.shape[:-1]``.
    """
    lead_shape = logits.shape[:-1]
    vocab = logits.shape[-1]
    rows = logits.reshape(-1, vocab)
    # Entropy per chunk: H = -sum(p * log p) with p = exp(log_softmax).
    per_chunk = [
        -(torch.exp(logps) * logps).sum(-1)
        for logps in (
            F.log_softmax(chunk, dim=-1) for chunk in rows.split(chunk_size, dim=0)
        )
    ]
    return torch.cat(per_chunk, dim=0).reshape(lead_shape)
|
||||
|
||||
|
||||
def _clean_gpu():
    """Drop Python garbage and reset the CUDA allocator's memory counters."""
    gc.collect()
    cuda = torch.cuda
    cuda.empty_cache()
    # Zero both peak and accumulated stats so each measurement starts fresh.
    cuda.reset_peak_memory_stats()
    cuda.reset_accumulated_memory_stats()
    cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, logits, n_iters=BENCH_ITERS):
    """Time ``fn(logits, chunk_size=128)`` with CUDA events; returns ms per iter."""
    # Warmup absorbs JIT compilation, autotuning and allocator growth.
    for _ in range(WARMUP):
        out = fn(logits, chunk_size=128)
        del out
    torch.cuda.synchronize()

    elapsed = []
    for _ in range(n_iters):
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        out = fn(logits, chunk_size=128)
        end_evt.record()
        # Events are recorded asynchronously; sync before reading the delta.
        torch.cuda.synchronize()
        elapsed.append(start_evt.elapsed_time(end_evt))
        del out
    return elapsed
|
||||
|
||||
|
||||
def profile_memory(fn, logits, n_iters=MEM_ITERS):
    """Measure peak allocator overhead (MB) of ``fn(logits, chunk_size=128)``."""
    for _ in range(WARMUP):
        out = fn(logits, chunk_size=128)
        del out
    torch.cuda.synchronize()

    overheads = []
    for _ in range(n_iters):
        # Reset counters so the peak reflects only this call's allocations.
        _clean_gpu()
        baseline = torch.cuda.max_memory_allocated()
        out = fn(logits, chunk_size=128)
        torch.cuda.synchronize()
        overheads.append(torch.cuda.max_memory_allocated() - baseline)
        del out
    return [b / 1e6 for b in overheads]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
    """Render mean ± std plus min/max of *values* as a fixed-width string."""
    avg = statistics.mean(values)
    # stdev needs at least two samples; report 0.0 for a single measurement.
    spread = 0.0 if len(values) < 2 else statistics.stdev(values)
    lo, hi = min(values), max(values)
    return f"{avg:8.2f} ± {spread:5.2f} {unit} [min={lo:.2f}, max={hi:.2f}]"
|
||||
|
||||
|
||||
def benchmark_contiguous():
    """Benchmark original vs Triton entropy on contiguous (B, L, V) logits."""
    print("=" * 60)
    print(
        f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
    )
    print("=" * 60)

    # (batch, sequence-length) shapes to sweep; vocab size is fixed at V.
    configs = [
        (1, 2048),
        (1, 8192),
        (1, 16384),
        (4, 4096),
        (8, 2048),
        (16, 2048),
        (16, 4096),
    ]

    for B, L in configs:
        mem_gb = B * L * V * 2 / 1e9  # bf16 logits: 2 bytes per element
        if mem_gb > 28:
            # Skip shapes whose logits alone would not fit in VRAM.
            print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
            continue

        N = B * L
        print(f"\n{'─' * 60}")
        print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
        print(f"{'─' * 60}")

        # Fixed seed so both implementations see identical inputs.
        torch.manual_seed(42)
        logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)

        t_orig = profile_time(entropy_from_logits_original, logits)
        t_triton = profile_time(entropy_from_logits, logits)
        orig_mean = statistics.mean(t_orig)
        triton_mean = statistics.mean(t_triton)

        print(" TIME (ms):")
        print(f" original: {fmt(t_orig, 'ms')}")
        print(f" triton: {fmt(t_triton, 'ms')}")
        print(f" speedup: {orig_mean / triton_mean:.2f}x")

        m_orig = profile_memory(entropy_from_logits_original, logits)
        m_triton = profile_memory(entropy_from_logits, logits)
        orig_peak = statistics.mean(m_orig)
        triton_peak = statistics.mean(m_triton)

        print(" MEMORY (peak overhead):")
        print(f" original: {fmt(m_orig, 'MB')}")
        print(f" triton: {fmt(m_triton, 'MB')}")
        print(f" saved: {orig_peak - triton_peak:.1f} MB")

        # Free the large logits tensor before moving to the next shape.
        del logits
        _clean_gpu()
|
||||
|
||||
|
||||
def benchmark_noncontiguous():
    """Benchmark strided Triton entropy vs copy-then-original on non-contiguous views."""
    print("\n" + "=" * 60)
    print(
        f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
    )
    print("=" * 60)

    # (B, L, how-the-non-contiguous-view-is-made); the vocab dim stays
    # contiguous in both cases, which the strided kernel requires.
    configs = [
        (4, 2048, "transpose"),
        (4, 8192, "transpose"),
        (8, 2048, "transpose"),
        (4, 4096, "slice_batch"),
    ]

    for B, L, method in configs:
        torch.manual_seed(42)

        if method == "transpose":
            # (L, B, V) storage viewed as (B, L, V): outer dims permuted.
            raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
            logits_nc = raw.transpose(0, 1)
            raw_gb = L * B * V * 2 / 1e9
        elif method == "slice_batch":
            # Every other batch row: strided along dim 0.
            raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
            logits_nc = raw[::2]
            raw_gb = B * 2 * L * V * 2 / 1e9
        else:
            continue

        if raw_gb > 28:
            print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
            del raw, logits_nc
            torch.cuda.empty_cache()
            continue

        N = B * L
        print(f"\n{'─' * 60}")
        print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
        print(f"{'─' * 60}")

        def original_with_copy(logits, chunk_size=128):
            # The original implementation needs contiguous input, so the
            # cost of the .contiguous() copy is charged to it.
            return entropy_from_logits_original(
                logits.contiguous(), chunk_size=chunk_size
            )

        t_orig = profile_time(original_with_copy, logits_nc)
        t_triton = profile_time(entropy_from_logits, logits_nc)
        orig_mean = statistics.mean(t_orig)
        triton_mean = statistics.mean(t_triton)

        print(" TIME (ms):")
        print(f" orig+copy: {fmt(t_orig, 'ms')}")
        print(f" triton-strided:{fmt(t_triton, 'ms')}")
        print(f" speedup: {orig_mean / triton_mean:.2f}x")

        m_orig = profile_memory(original_with_copy, logits_nc)
        m_triton = profile_memory(entropy_from_logits, logits_nc)
        orig_peak = statistics.mean(m_orig)
        triton_peak = statistics.mean(m_triton)

        print(" MEMORY (peak overhead):")
        print(f" orig+copy: {fmt(m_orig, 'MB')}")
        print(f" triton-strided:{fmt(m_triton, 'MB')}")
        print(f" saved: {orig_peak - triton_peak:.1f} MB")

        # Drop both the backing storage and the view before the next config.
        del raw, logits_nc
        _clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run both suites back to back on the visible CUDA device.
    benchmark_contiguous()
    benchmark_noncontiguous()
|
||||
191
benchmarks/bench_selective_logsoftmax.py
Normal file
191
benchmarks/bench_selective_logsoftmax.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import (
|
||||
selective_log_softmax,
|
||||
selective_log_softmax_original,
|
||||
)
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def _clean_gpu():
    """Collect Python garbage and reset CUDA allocator/memory statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    # Reset both stat families so subsequent peak readings start from zero.
    for reset in (
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    ):
        reset()
    torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, args, n_iters=BENCH_ITERS):
    """Time ``fn(*args)`` with CUDA events; returns a list of ms per iteration."""
    # Warmup absorbs kernel compilation and allocator growth.
    for _ in range(WARMUP):
        fn(*args)
    torch.cuda.synchronize()

    samples = []
    for _ in range(n_iters):
        begin = torch.cuda.Event(enable_timing=True)
        finish = torch.cuda.Event(enable_timing=True)
        begin.record()
        fn(*args)
        finish.record()
        # Sync so both events have completed before reading the elapsed time.
        torch.cuda.synchronize()
        samples.append(begin.elapsed_time(finish))
    return samples
|
||||
|
||||
|
||||
def profile_memory(fn, args, n_iters=MEM_ITERS):
    """Measure peak CUDA memory overhead (MB) of each ``fn(*args)`` call."""
    for _ in range(WARMUP):
        out = fn(*args)
        del out
    torch.cuda.synchronize()

    peak_bytes = []
    for _ in range(n_iters):
        # Reset counters so the peak reflects only this iteration.
        _clean_gpu()
        baseline = torch.cuda.max_memory_allocated()
        out = fn(*args)
        torch.cuda.synchronize()
        peak_bytes.append(torch.cuda.max_memory_allocated() - baseline)
        del out
    return [b / 1e6 for b in peak_bytes]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
    """Format mean ± std and the min/max range of *values* for console output."""
    center = statistics.mean(values)
    deviation = statistics.stdev(values) if len(values) > 1 else 0.0
    extremes = f"[min={min(values):.2f}, max={max(values):.2f}]"
    return f"{center:8.2f} ± {deviation:5.2f} {unit} {extremes}"
|
||||
|
||||
|
||||
def benchmark_forward():
    """Forward-only comparison of original vs Triton selective_log_softmax."""
    print("=" * 60)
    print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
    print("=" * 60)

    # (batch, sequence-length) shapes to sweep; vocab size fixed at V.
    configs = [
        (1, 2048),
        (1, 8192),
        (4, 4096),
        (8, 2048),
        (16, 2048),
        (16, 4096),
    ]

    for B, L in configs:
        mem_gb = B * L * V * 2 / 1e9  # bf16 logits: 2 bytes per element
        if mem_gb > 28:
            # Skip shapes whose logits alone would not fit in VRAM.
            print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
            continue

        N = B * L
        print(f"\n{'─' * 60}")
        print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
        print(f"{'─' * 60}")

        # Fixed seed so both implementations see identical inputs.
        torch.manual_seed(42)
        logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
        index = torch.randint(0, V, (B, L), device="cuda")

        t_orig = profile_time(selective_log_softmax_original, (logits, index))
        t_triton = profile_time(selective_log_softmax, (logits, index))
        orig_mean = statistics.mean(t_orig)
        triton_mean = statistics.mean(t_triton)

        print(" TIME (ms):")
        print(f" original: {fmt(t_orig, 'ms')}")
        print(f" triton: {fmt(t_triton, 'ms')}")
        print(f" speedup: {orig_mean / triton_mean:.2f}x")

        m_orig = profile_memory(selective_log_softmax_original, (logits, index))
        m_triton = profile_memory(selective_log_softmax, (logits, index))
        orig_peak = statistics.mean(m_orig)
        triton_peak = statistics.mean(m_triton)

        print(" MEMORY (peak overhead):")
        print(f" original: {fmt(m_orig, 'MB')}")
        print(f" triton: {fmt(m_triton, 'MB')}")
        print(f" saved: {orig_peak - triton_peak:.1f} MB")

        # Release the large tensors before the next shape.
        del logits, index
        _clean_gpu()
|
||||
|
||||
|
||||
def benchmark_backward():
    """Forward+backward comparison of original vs Triton selective_log_softmax."""
    print("\n" + "=" * 60)
    print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
    print("=" * 60)

    configs = [
        (1, 2048),
        (1, 8192),
        (4, 4096),
        (8, 2048),
        (16, 2048),
        (16, 4096),
    ]

    def fwd_bwd_original(logits, index):
        # Clear accumulated grads so each iteration measures one backward.
        logits.grad = None
        out = selective_log_softmax_original(logits, index)
        out.sum().backward()

    def fwd_bwd_triton(logits, index):
        logits.grad = None
        out = selective_log_softmax(logits, index)
        out.sum().backward()

    for B, L in configs:
        mem_gb = B * L * V * 2 / 1e9  # bf16 logits only; grads add more
        if mem_gb > 20:
            # Tighter limit than forward-only: gradients need extra VRAM.
            print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
            continue

        N = B * L
        print(f"\n{'─' * 60}")
        print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
        print(f"{'─' * 60}")

        torch.manual_seed(42)
        logits_orig = torch.randn(
            B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        # Independent leaf tensor with identical values for the Triton path,
        # so each implementation accumulates into its own .grad.
        logits_tri = logits_orig.detach().clone().requires_grad_(True)
        index = torch.randint(0, V, (B, L), device="cuda")

        t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
        t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
        orig_mean = statistics.mean(t_orig)
        triton_mean = statistics.mean(t_triton)

        print(" FWD+BWD TIME (ms):")
        print(f" original: {fmt(t_orig, 'ms')}")
        print(f" triton: {fmt(t_triton, 'ms')}")
        print(f" speedup: {orig_mean / triton_mean:.2f}x")

        m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
        m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
        orig_peak = statistics.mean(m_orig)
        triton_peak = statistics.mean(m_triton)

        print(" FWD+BWD MEMORY (peak overhead):")
        print(f" original: {fmt(m_orig, 'MB')}")
        print(f" triton: {fmt(m_triton, 'MB')}")
        print(f" saved: {orig_peak - triton_peak:.1f} MB")

        del logits_orig, logits_tri, index
        _clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Forward-only suite first, then the forward+backward suite.
    benchmark_forward()
    benchmark_backward()
|
||||
@@ -117,6 +117,7 @@ class PatchManager:
|
||||
self._apply_voxtral_patches()
|
||||
self._apply_apertus_patches()
|
||||
self._apply_trl_vllm_patches()
|
||||
self._apply_trl_trainer_utils_patches()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
@@ -679,6 +680,39 @@ class PatchManager:
|
||||
|
||||
patch_trl_vllm()
|
||||
|
||||
def _apply_trl_trainer_utils_patches(self):
    """Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels.

    No-op unless an RL trainer is configured. If the Triton-backed module
    cannot be imported (e.g. triton missing), logs a warning and leaves
    trl untouched.
    """
    # Only RL training paths route through these trl helpers.
    if not self.cfg.rl:
        return

    try:
        from axolotl.monkeypatch.trainer.utils import (
            entropy_from_logits,
            selective_log_softmax,
        )
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        LOG.warning("Triton not available — skipping trl.trainer.utils patches")
        return

    import trl.trainer.utils

    # Guard against repeated calls: only stash the original if trl still
    # points at its own implementation (not our wrapper).
    if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
        from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils

        # Keep the pristine trl implementation reachable as the
        # reference/fallback used by the Triton wrapper.
        _axolotl_trainer_utils.selective_log_softmax_original = (
            trl.trainer.utils.selective_log_softmax
        )
        trl.trainer.utils.selective_log_softmax = selective_log_softmax

    if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
        trl.trainer.utils.entropy_from_logits = entropy_from_logits

    LOG.info(
        "Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits"
    )
|
||||
|
||||
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
|
||||
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
|
||||
if self.cfg.scaling_softmax:
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .utils import entropy_from_logits, selective_log_softmax
|
||||
|
||||
__all__ = ["entropy_from_logits", "selective_log_softmax"]
|
||||
|
||||
429
src/axolotl/monkeypatch/trainer/utils.py
Normal file
429
src/axolotl/monkeypatch/trainer/utils.py
Normal file
@@ -0,0 +1,429 @@
|
||||
# Copyright 2026 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _entropy_online_kernel(
    logits_ptr,
    output_ptr,
    stride_row,
    V: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Online entropy: single pass with running max correction.

    One program instance handles one row of V logits. Uses the
    online-softmax recurrence: whenever a larger max appears, previously
    accumulated sums are rescaled by exp(old_max - new_max). Writes one
    float32 entropy value per row.
    """
    row = tl.program_id(0)
    # int64 cast prevents 32-bit overflow of row * stride_row on large tensors.
    row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row

    # Running statistics for the streaming logsumexp / weighted-sum pass.
    running_max = tl.full([], float("-inf"), dtype=tl.float32)
    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
    running_weighted = tl.full([], 0.0, dtype=tl.float32)

    for v_start in range(0, V, BLOCK_V):
        offs = v_start + tl.arange(0, BLOCK_V)
        mask = offs < V
        x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)

        block_max = tl.max(x, axis=0)
        new_max = tl.maximum(running_max, block_max)

        # Rescale previously accumulated sums to the new reference max.
        correction = tl.exp(running_max - new_max)
        running_sum_exp = running_sum_exp * correction
        running_weighted = running_weighted * correction

        exp_x = tl.exp(x - new_max)
        # Zero out-of-range lanes so they contribute nothing to the sums.
        exp_x = tl.where(mask, exp_x, 0.0)
        x = tl.where(mask, x, 0.0)
        running_sum_exp += tl.sum(exp_x, axis=0)
        running_weighted += tl.sum(exp_x * x, axis=0)

        running_max = new_max

    # H = logsumexp - E[x], with E[x] = sum(exp(x-m)*x) / sum(exp(x-m)).
    entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
    tl.store(output_ptr + row, entropy)
|
||||
|
||||
|
||||
@triton.jit
def _entropy_online_kernel_strided(
    logits_ptr,
    output_ptr,
    stride_outer,
    stride_inner,
    n_inner,
    row_offset,
    V: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Online entropy for non-contiguous 3D (B, L, V) tensors.

    Same online-softmax recurrence as `_entropy_online_kernel`, but the
    global row id is decomposed into (outer, inner) coordinates and mapped
    through explicit strides, so permuted/sliced outer dims work as long
    as the vocab dim itself is contiguous. `row_offset` shifts the program
    id when the launch is split across multiple grids.
    """
    local_row = tl.program_id(0)
    row = local_row + row_offset
    # Decompose the flat row index into (outer, inner) coordinates.
    outer_idx = row // n_inner
    inner_idx = row % n_inner
    # int64 arithmetic avoids overflow of the byte offset on large tensors.
    off = outer_idx.to(tl.int64) * stride_outer + inner_idx.to(tl.int64) * stride_inner
    row_ptr = logits_ptr + off

    # Running statistics for the streaming logsumexp / weighted-sum pass.
    running_max = tl.full([], float("-inf"), dtype=tl.float32)
    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
    running_weighted = tl.full([], 0.0, dtype=tl.float32)

    for v_start in range(0, V, BLOCK_V):
        offs = v_start + tl.arange(0, BLOCK_V)
        mask = offs < V
        x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)

        block_max = tl.max(x, axis=0)
        new_max = tl.maximum(running_max, block_max)

        # Rescale previously accumulated sums to the new reference max.
        correction = tl.exp(running_max - new_max)
        running_sum_exp = running_sum_exp * correction
        running_weighted = running_weighted * correction

        exp_x = tl.exp(x - new_max)
        # Zero out-of-range lanes so they contribute nothing to the sums.
        exp_x = tl.where(mask, exp_x, 0.0)
        x = tl.where(mask, x, 0.0)
        running_sum_exp += tl.sum(exp_x, axis=0)
        running_weighted += tl.sum(exp_x * x, axis=0)

        running_max = new_max

    # H = logsumexp - E[x] under softmax weights.
    entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
    # Output is indexed by the grid-local row; the caller offsets the buffer.
    tl.store(output_ptr + local_row, entropy)
|
||||
|
||||
|
||||
def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    """Triton-fused entropy (online single-pass). Handles non-contiguous tensors without copying.

    Computes the entropy of softmax(logits) along the last dimension. CUDA
    tensors use the Triton kernels; 3D tensors whose vocab dim is already
    contiguous are processed in place via the strided kernel, everything
    else is made contiguous first. CPU tensors fall back to a stable
    log_softmax-based computation.

    Args:
        logits: tensor of shape (..., V).
        chunk_size: unused by the fused kernel; kept so this is a drop-in
            replacement for the original chunked implementation.

    Returns:
        Tensor of shape ``logits.shape[:-1]`` in ``logits.dtype``.
    """
    original_shape = logits.shape[:-1]
    V = logits.shape[-1]
    N = original_shape.numel()

    if not logits.is_cuda:
        # CPU fallback: stable entropy via log_softmax
        logp = F.log_softmax(logits.float(), dim=-1)
        ent = -(logp.exp() * logp).sum(dim=-1)
        return ent.to(logits.dtype).reshape(original_shape)

    output = torch.empty(N, device=logits.device, dtype=torch.float32)

    BLOCK_V = 4096
    MAX_GRID_CONTIG = 8192
    MAX_GRID_STRIDED = 2048

    def _launch_contiguous(t: torch.Tensor) -> None:
        # Shared launch path for contiguous inputs: one launch per
        # MAX_GRID_CONTIG-row slab to stay within grid-size limits.
        flat = t.reshape(-1, V)
        stride = flat.stride(0)
        for start in range(0, N, MAX_GRID_CONTIG):
            n_rows = min(MAX_GRID_CONTIG, N - start)
            _entropy_online_kernel[(n_rows,)](
                flat[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
            )

    # Vocab (last) dim must be contiguous for coalesced loads
    if logits.stride(-1) != 1:
        logits = logits.contiguous()

    if logits.is_contiguous():
        _launch_contiguous(logits)
    elif logits.ndim == 3:
        # Strided kernel walks (B, L) rows in place — avoids a full copy.
        stride_outer = logits.stride(0)
        stride_inner = logits.stride(1)
        n_inner = logits.shape[1]
        for start in range(0, N, MAX_GRID_STRIDED):
            n_rows = min(MAX_GRID_STRIDED, N - start)
            _entropy_online_kernel_strided[(n_rows,)](
                logits,
                output[start],
                stride_outer,
                stride_inner,
                n_inner,
                start,
                V=V,
                BLOCK_V=BLOCK_V,
            )
    else:
        # Rare case: >3D and non-contiguous — copy once, then fast path.
        logits = logits.contiguous()
        _launch_contiguous(logits)

    return output.to(logits.dtype).reshape(original_shape)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# selective_log_softmax — fused forward + backward Triton kernels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def selective_log_softmax_original(logits, index) -> torch.Tensor:
    """Original selective_log_softmax (reference/fallback).

    Gathers ``logits.log_softmax(-1)`` at ``index``, computed slice-by-slice
    along the leading dim to bound peak memory. ``index`` may come with or
    without the trailing gather dimension.
    """
    needs_unsqueeze = index.ndim == logits.ndim - 1
    if needs_unsqueeze:
        index = index.unsqueeze(-1)

    if logits.dtype in (torch.float32, torch.float64):
        # High precision: subtract a per-slice logsumexp from raw gathers.
        gathered = torch.gather(logits, dim=-1, index=index)
        lse = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
        result = gathered - lse.unsqueeze(-1)
    else:
        # Lower precision (e.g. bf16): log_softmax per slice for stability.
        rows = [
            F.log_softmax(row_logits, dim=-1).gather(dim=-1, index=row_labels)
            for row_logits, row_labels in zip(logits, index, strict=True)
        ]
        result = torch.stack(rows)

    if needs_unsqueeze:
        result = result.squeeze(-1)

    return result
|
||||
|
||||
|
||||
@triton.jit
def _selective_logsoftmax_fwd_kernel(
    logits_ptr,
    index_ptr,
    output_ptr,
    logsumexp_ptr,
    stride_logits_row,
    stride_index_row,
    stride_output_row,
    actual_K,
    K_BLOCK: tl.constexpr,
    V: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Forward: online logsumexp + gather. Saves logsumexp for backward.

    One program per row: streams the V logits once to compute logsumexp,
    then gathers the K selected logits and stores selected - lse.
    """
    row = tl.program_id(0)
    # int64 cast prevents 32-bit overflow of row * stride on large tensors.
    logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row

    # Online logsumexp
    running_max = tl.full([], float("-inf"), dtype=tl.float32)
    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)

    for v_start in range(0, V, BLOCK_V):
        offs = v_start + tl.arange(0, BLOCK_V)
        mask = offs < V
        x = tl.load(logits_row_ptr + offs, mask=mask, other=float("-inf")).to(
            tl.float32
        )

        block_max = tl.max(x, axis=0)
        new_max = tl.maximum(running_max, block_max)
        # Rescale the running sum whenever a new maximum appears.
        running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)

        exp_x = tl.exp(x - new_max)
        # Zero out-of-range lanes so they contribute nothing to the sum.
        exp_x = tl.where(mask, exp_x, 0.0)
        running_sum_exp += tl.sum(exp_x, axis=0)
        running_max = new_max

    lse = tl.log(running_sum_exp) + running_max
    # Stash the row's logsumexp so backward can recompute softmax cheaply.
    tl.store(logsumexp_ptr + row, lse)

    # Gather and subtract
    index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
    output_row_ptr = output_ptr + tl.cast(row, tl.int64) * stride_output_row

    k_offs = tl.arange(0, K_BLOCK)
    k_mask = k_offs < actual_K
    indices = tl.load(index_row_ptr + k_offs, mask=k_mask, other=0).to(tl.int64)
    # Only gather/store where the index is in range [0, V); invalid slots
    # are clamped for the load and skipped for the store.
    valid_mask = k_mask & (indices >= 0) & (indices < V)
    safe_indices = tl.where(valid_mask, indices, 0)
    selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask, other=0.0).to(
        tl.float32
    )
    tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)
|
||||
|
||||
|
||||
@triton.jit
def _selective_logsoftmax_bwd_kernel(
    grad_output_ptr,
    logits_ptr,
    index_ptr,
    logsumexp_ptr,
    grad_logits_ptr,
    stride_grad_out_row,
    stride_logits_row,
    stride_index_row,
    stride_grad_logits_row,
    actual_K,
    K_BLOCK: tl.constexpr,
    V: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Backward: d_logits[j] = -softmax(x)[j] * sum(grad_out) + (grad_out[k] if j == index[k]).

    Single fused pass over V. For each tile, computes the base gradient and adds
    scatter contributions inline by checking which indices fall in the current tile.
    No separate scatter pass — no read-after-write issues.
    """
    row = tl.program_id(0)
    # Per-row base pointers; int64 casts avoid 32-bit offset overflow.
    logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
    grad_logits_row_ptr = (
        grad_logits_ptr + tl.cast(row, tl.int64) * stride_grad_logits_row
    )
    grad_out_row_ptr = grad_output_ptr + tl.cast(row, tl.int64) * stride_grad_out_row
    index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row

    # logsumexp saved by the forward pass; softmax(x)[j] = exp(x[j] - lse).
    lse = tl.load(logsumexp_ptr + row).to(tl.float32)

    # Load grad_output and indices (K_BLOCK elements, masked)
    k_offs = tl.arange(0, K_BLOCK)
    k_mask = k_offs < actual_K
    grad_out = tl.load(grad_out_row_ptr + k_offs, mask=k_mask, other=0.0).to(tl.float32)
    indices = tl.load(
        index_row_ptr + k_offs, mask=k_mask, other=-1
    )  # -1 = never matches
    # Invalid (out-of-range) indices contribute no gradient at all.
    valid_mask = k_mask & (indices >= 0) & (indices < V)
    grad_out = tl.where(valid_mask, grad_out, 0.0)
    indices = tl.where(valid_mask, indices, -1)
    grad_sum = tl.sum(grad_out, axis=0)

    # Fused pass: for each tile, compute -softmax * grad_sum + scatter
    for v_start in range(0, V, BLOCK_V):
        offs = v_start + tl.arange(0, BLOCK_V)  # [BLOCK_V]
        mask = offs < V
        x = tl.load(logits_row_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        softmax_j = tl.exp(x - lse)
        softmax_j = tl.where(mask, softmax_j, 0.0)
        grad_j = -softmax_j * grad_sum

        # Scatter: check which selected indices fall in this tile
        # offs: [BLOCK_V], indices: [K_BLOCK]
        # Broadcast: offs[:, None] == indices[None, :] → [BLOCK_V, K_BLOCK]
        match = offs[:, None] == indices[None, :]  # [BLOCK_V, K_BLOCK]
        # Sum grad_out contributions: for each position j, sum grad_out[k] where index[k]==j
        scatter_contrib = tl.sum(
            tl.where(match, grad_out[None, :], 0.0), axis=1
        )  # [BLOCK_V]
        grad_j += scatter_contrib

        tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)
|
||||
|
||||
|
||||
class _SelectiveLogSoftmaxTriton(torch.autograd.Function):
    """autograd.Function wrapping the fused Triton fwd/bwd kernels.

    ``forward`` returns a float32 (N, K_BLOCK) tensor, where K_BLOCK is the
    power-of-two padding of K; the caller slices back down to K columns.
    The saved per-row logsumexp lets ``backward`` recompute softmax without
    a second reduction pass.
    """

    @staticmethod
    def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID):
        N = flat_logits.shape[0]
        # zeros, not empty: the fwd kernel only stores where the index is
        # valid (0 <= idx < V), so with empty() any invalid/out-of-range
        # index position would leak uninitialized memory to the caller.
        output = torch.zeros(N, K_BLOCK, device=flat_logits.device, dtype=torch.float32)
        logsumexp = torch.empty(N, device=flat_logits.device, dtype=torch.float32)

        # Launch in MAX_GRID-row slabs to stay within grid-size limits.
        for start in range(0, N, MAX_GRID):
            n_rows = min(MAX_GRID, N - start)
            _selective_logsoftmax_fwd_kernel[(n_rows,)](
                flat_logits[start],
                flat_index[start],
                output[start],
                logsumexp[start],
                flat_logits.stride(0),
                flat_index.stride(0),
                output.stride(0),
                K,
                K_BLOCK=K_BLOCK,
                V=V,
                BLOCK_V=BLOCK_V,
            )

        ctx.save_for_backward(flat_logits, flat_index, logsumexp)
        ctx.K = K
        ctx.K_BLOCK = K_BLOCK
        ctx.V = V
        ctx.BLOCK_V = BLOCK_V
        ctx.MAX_GRID = MAX_GRID
        return output

    @staticmethod
    def backward(ctx, grad_output):
        flat_logits, flat_index, logsumexp = ctx.saved_tensors
        K, K_BLOCK, V, BLOCK_V, MAX_GRID = (
            ctx.K,
            ctx.K_BLOCK,
            ctx.V,
            ctx.BLOCK_V,
            ctx.MAX_GRID,
        )
        N = flat_logits.shape[0]

        # Every element of grad_logits is written by the bwd kernel, so
        # uninitialized allocation is safe here.
        grad_logits = torch.empty_like(flat_logits)

        # grad_output may have K_BLOCK cols; backward kernel reads actual_K
        grad_output_contig = grad_output.contiguous()

        for start in range(0, N, MAX_GRID):
            n_rows = min(MAX_GRID, N - start)
            _selective_logsoftmax_bwd_kernel[(n_rows,)](
                grad_output_contig[start],
                flat_logits[start],
                flat_index[start],
                logsumexp[start],
                grad_logits[start],
                grad_output_contig.stride(0),
                flat_logits.stride(0),
                flat_index.stride(0),
                grad_logits.stride(0),
                K,
                K_BLOCK=K_BLOCK,
                V=V,
                BLOCK_V=BLOCK_V,
            )

        # Return grads for: flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
        return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
def selective_log_softmax(logits, index) -> torch.Tensor:
    """
    Fused selective_log_softmax with Triton forward+backward kernels.

    Equivalent to: torch.gather(logits.log_softmax(-1), dim=-1, index=index)
    """
    squeeze = index.ndim == logits.ndim - 1
    if squeeze:
        index = index.unsqueeze(-1)

    if not logits.is_cuda or logits.dtype == torch.float64:
        # Triton kernel computes in float32; fall back for float64 and CPU
        fallback_index = index.squeeze(-1) if squeeze else index
        return selective_log_softmax_original(logits, fallback_index)

    V = logits.shape[-1]
    K = index.shape[-1]
    index_shape = index.shape

    # Kernels assume row-contiguous 2D inputs.
    flat_logits = logits.reshape(-1, V).contiguous()
    flat_index = index.reshape(-1, K).contiguous()

    BLOCK_V = 4096
    MAX_GRID = 8192
    K_BLOCK = max(1, triton.next_power_of_2(K))

    out = _SelectiveLogSoftmaxTriton.apply(
        flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
    )

    # The kernel writes K_BLOCK (power-of-two) columns; trim the padding.
    if K_BLOCK != K:
        out = out[:, :K]

    per_token_logps = out.to(logits.dtype).reshape(index_shape)
    return per_token_logps.squeeze(-1) if squeeze else per_token_logps
|
||||
481
tests/test_triton_kernels.py
Normal file
481
tests/test_triton_kernels.py
Normal file
@@ -0,0 +1,481 @@
|
||||
# Copyright 2026 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
"""Unit tests for Triton kernels: entropy_from_logits and selective_log_softmax.
|
||||
|
||||
Adapted from harness/test_entropy.py and harness/test_selective_logsoftmax.py
|
||||
into proper pytest tests, plus new OOB index safety tests.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="CUDA required for Triton kernels"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reference implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ref_entropy(logits):
|
||||
"""Reference entropy via log_softmax (numerically stable)."""
|
||||
logp = F.log_softmax(logits.float(), dim=-1)
|
||||
return -(logp.exp() * logp).sum(dim=-1)
|
||||
|
||||
|
||||
def _ref_selective_log_softmax(logits, index):
|
||||
"""Reference selective log softmax via PyTorch gather."""
|
||||
squeeze = index.ndim == logits.ndim - 1
|
||||
if squeeze:
|
||||
index = index.unsqueeze(-1)
|
||||
log_probs = F.log_softmax(logits.float(), dim=-1)
|
||||
result = torch.gather(log_probs, dim=-1, index=index)
|
||||
if squeeze:
|
||||
result = result.squeeze(-1)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# entropy_from_logits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntropyFromLogits:
    """Forward-correctness tests for the Triton entropy_from_logits kernel.

    Each test compares the kernel against a float32 log_softmax reference.
    """

    @pytest.mark.parametrize(
        "B,L",
        [
            (1, 128),
            (1, 2048),
            (4, 512),
            (8, 256),
            (1, 1),
        ],
    )
    def test_correctness_various_shapes(self, B, L):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        torch.manual_seed(42)
        x = torch.randn(B, L, 1024, device="cuda", dtype=torch.float32)
        got = entropy_from_logits(x)
        assert got.shape == (B, L)
        torch.testing.assert_close(got, _ref_entropy(x), atol=1e-4, rtol=1e-4)

    def test_2d_input(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        x = torch.randn(16, 256, device="cuda", dtype=torch.float32)
        got = entropy_from_logits(x)
        assert got.shape == (16,)
        torch.testing.assert_close(got, _ref_entropy(x), atol=1e-4, rtol=1e-4)

    def test_large_vocab(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        x = torch.randn(2, 32000, device="cuda", dtype=torch.float32)
        torch.testing.assert_close(
            entropy_from_logits(x), _ref_entropy(x), atol=1e-4, rtol=1e-4
        )

    def test_uniform_distribution(self):
        """Uniform logits have maximal entropy log(V)."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        vocab = 1024
        x = torch.zeros(2, vocab, device="cuda", dtype=torch.float32)
        want = torch.full((2,), math.log(vocab), device="cuda", dtype=torch.float32)
        torch.testing.assert_close(entropy_from_logits(x), want, atol=1e-4, rtol=1e-4)

    def test_peaked_distribution(self):
        """Near-one-hot logits produce entropy close to zero."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        x = torch.full((2, 128), -100.0, device="cuda", dtype=torch.float32)
        x[:, 0] = 100.0
        assert (entropy_from_logits(x) < 1e-3).all()

    def test_bfloat16(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
        got = entropy_from_logits(x)
        assert got.dtype == torch.bfloat16
        want = _ref_entropy(x.float())
        torch.testing.assert_close(got.float(), want, atol=5e-2, rtol=5e-2)

    def test_float16(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
        got = entropy_from_logits(x)
        assert got.dtype == torch.float16
        want = _ref_entropy(x.float())
        torch.testing.assert_close(got.float(), want, atol=5e-2, rtol=5e-2)

    def test_non_contiguous_3d_transpose(self):
        """A transposed (hence non-contiguous) 3D input must still be correct."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        base = torch.randn(32, 4, 256, device="cuda", dtype=torch.float32)
        x = base.transpose(0, 1)  # (4, 32, 256) non-contiguous view
        assert not x.is_contiguous()
        torch.testing.assert_close(
            entropy_from_logits(x), _ref_entropy(x), atol=1e-4, rtol=1e-4
        )

    def test_non_contiguous_3d_slice(self):
        """A strided batch slice (non-contiguous) must still be correct."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        base = torch.randn(8, 32, 256, device="cuda", dtype=torch.float32)
        x = base[::2]  # (4, 32, 256) non-contiguous view
        assert not x.is_contiguous()
        torch.testing.assert_close(
            entropy_from_logits(x), _ref_entropy(x), atol=1e-4, rtol=1e-4
        )

    def test_many_rows_beyond_max_grid(self):
        """Row count above MAX_GRID (8192) exercises the chunked dispatch path."""
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        x = torch.randn(10000, 128, device="cuda", dtype=torch.float32)
        torch.testing.assert_close(
            entropy_from_logits(x), _ref_entropy(x), atol=1e-4, rtol=1e-4
        )

    def test_entropy_non_negative(self):
        from axolotl.monkeypatch.trainer.utils import entropy_from_logits

        result = entropy_from_logits(
            torch.randn(32, 512, device="cuda", dtype=torch.float32)
        )
        assert (result >= -1e-5).all(), f"Negative entropy: {result.min()}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# selective_log_softmax — forward correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSelectiveLogSoftmax:
    """Forward-correctness tests for the Triton selective_log_softmax kernel."""

    @pytest.mark.parametrize(
        "B,L,K",
        [
            (1, 128, 1),
            (4, 512, 1),
            (8, 256, 1),
            (4, 256, 4),
            (4, 256, 7),
            (15, 129, 1),  # non-power-of-2 dimensions
        ],
    )
    def test_correctness_various_shapes(self, B, L, K):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        torch.manual_seed(42)
        x = torch.randn(B, L, 1024, device="cuda", dtype=torch.float32)
        idx_shape = (B, L) if K == 1 else (B, L, K)
        idx = torch.randint(0, 1024, idx_shape, device="cuda")
        torch.testing.assert_close(
            selective_log_softmax(x, idx),
            _ref_selective_log_softmax(x, idx),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_squeezed_index(self):
        """A (B,) index (ndim == logits.ndim - 1) exercises the squeeze path."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        x = torch.randn(8, 256, device="cuda", dtype=torch.float32)
        idx = torch.randint(0, 256, (8,), device="cuda")
        got = selective_log_softmax(x, idx)
        assert got.shape == (8,)
        torch.testing.assert_close(
            got, _ref_selective_log_softmax(x, idx), atol=1e-4, rtol=1e-4
        )

    def test_large_vocab(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        x = torch.randn(2, 32000, device="cuda", dtype=torch.float32)
        idx = torch.randint(0, 32000, (2, 1), device="cuda")
        torch.testing.assert_close(
            selective_log_softmax(x, idx),
            _ref_selective_log_softmax(x, idx),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_bfloat16(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        torch.manual_seed(42)
        x = torch.randn(4, 128, 1024, device="cuda", dtype=torch.bfloat16)
        idx = torch.randint(0, 1024, (4, 128), device="cuda")
        got = selective_log_softmax(x, idx)
        assert got.dtype == torch.bfloat16
        want = _ref_selective_log_softmax(x.float(), idx)
        torch.testing.assert_close(got.float(), want, atol=0.1, rtol=0.1)

    def test_fp32_tight_tolerance(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        torch.manual_seed(42)
        x = torch.randn(2, 256, 1024, device="cuda", dtype=torch.float32)
        idx = torch.randint(0, 1024, (2, 256), device="cuda")
        torch.testing.assert_close(
            selective_log_softmax(x, idx),
            _ref_selective_log_softmax(x, idx),
            atol=1e-5,
            rtol=1e-5,
        )

    def test_all_same_index(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        x = torch.randn(8, 128, device="cuda", dtype=torch.float32)
        idx = torch.zeros(8, 1, device="cuda", dtype=torch.long)
        torch.testing.assert_close(
            selective_log_softmax(x, idx),
            _ref_selective_log_softmax(x, idx),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_last_index(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        vocab = 128
        x = torch.randn(8, vocab, device="cuda", dtype=torch.float32)
        idx = torch.full((8, 1), vocab - 1, device="cuda", dtype=torch.long)
        torch.testing.assert_close(
            selective_log_softmax(x, idx),
            _ref_selective_log_softmax(x, idx),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_output_always_nonpositive(self):
        """log_softmax values are log-probabilities and therefore never positive."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        x = torch.randn(32, 256, device="cuda", dtype=torch.float32)
        idx = torch.randint(0, 256, (32, 1), device="cuda")
        result = selective_log_softmax(x, idx)
        assert (result <= 1e-5).all(), f"Positive log-prob: {result.max()}"

    def test_many_rows_beyond_max_grid(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        x = torch.randn(10000, 128, device="cuda", dtype=torch.float32)
        idx = torch.randint(0, 128, (10000, 1), device="cuda")
        torch.testing.assert_close(
            selective_log_softmax(x, idx),
            _ref_selective_log_softmax(x, idx),
            atol=1e-4,
            rtol=1e-4,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# selective_log_softmax — backward / gradient correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSelectiveLogSoftmaxBackward:
    """Gradient tests: the custom backward must match autograd on the reference."""

    @pytest.mark.parametrize(
        "B,L,V,K",
        [
            (2, 16, 64, 1),
            (2, 16, 64, 4),
            (1, 8, 128, 1),
            (2, 8, 128, 7),
        ],
    )
    def test_gradient_matches_reference(self, B, L, V, K):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        torch.manual_seed(42)
        ref_input = torch.randn(
            B, L, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        tri_input = ref_input.detach().clone().requires_grad_(True)
        idx_shape = (B, L) if K == 1 else (B, L, K)
        idx = torch.randint(0, V, idx_shape, device="cuda")

        _ref_selective_log_softmax(ref_input, idx).sum().backward()
        selective_log_softmax(tri_input, idx).sum().backward()

        torch.testing.assert_close(
            tri_input.grad, ref_input.grad, atol=1e-5, rtol=1e-5
        )

    def test_gradient_bfloat16_full_vocab(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        torch.manual_seed(42)
        ref_input = torch.randn(
            2, 64, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        tri_input = ref_input.detach().clone().requires_grad_(True)
        idx = torch.randint(0, 4096, (2, 64), device="cuda")

        _ref_selective_log_softmax(ref_input, idx).sum().backward()
        selective_log_softmax(tri_input, idx).sum().backward()

        torch.testing.assert_close(
            tri_input.grad.float(), ref_input.grad.float(), atol=0.1, rtol=0.1
        )

    def test_gradient_k1_squeezed(self):
        """Gradients flow correctly through the squeezed (1D index) path."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        x = torch.randn(
            8, 256, device="cuda", dtype=torch.float32, requires_grad=True
        )
        idx = torch.randint(0, 256, (8,), device="cuda")

        selective_log_softmax(x, idx).sum().backward()
        triton_grad = x.grad.clone()

        # Reset and recompute gradients via plain autograd for comparison.
        x.grad = None
        ref = torch.gather(
            F.log_softmax(x, dim=-1), dim=-1, index=idx.unsqueeze(-1)
        ).squeeze(-1)
        ref.sum().backward()

        torch.testing.assert_close(triton_grad, x.grad, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# selective_log_softmax — out-of-bounds index safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSelectiveLogSoftmaxOOBSafety:
    """Out-of-range indices must neither crash the kernel nor corrupt valid rows."""

    def test_negative_indices_no_crash(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        x = torch.randn(4, V, device="cuda", dtype=torch.float32)
        idx = torch.tensor(
            [[-1], [0], [V - 1], [-5]], device="cuda", dtype=torch.long
        )
        got = selective_log_softmax(x, idx)
        assert got.shape == (4, 1)
        # Rows 1 and 2 carry in-range indices; they must match the reference.
        ok_idx = torch.tensor([[0], [V - 1]], device="cuda", dtype=torch.long)
        want = _ref_selective_log_softmax(x[1:3], ok_idx)
        torch.testing.assert_close(got[1:3], want, atol=1e-4, rtol=1e-4)

    def test_index_exceeds_vocab_no_crash(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        x = torch.randn(4, V, device="cuda", dtype=torch.float32)
        idx = torch.tensor(
            [[0], [V], [V + 100], [V - 1]], device="cuda", dtype=torch.long
        )
        got = selective_log_softmax(x, idx)
        assert got.shape == (4, 1)
        # Rows 0 and 3 are valid; check each against the reference individually.
        for row, value in ((0, 0), (3, V - 1)):
            want = _ref_selective_log_softmax(
                x[row : row + 1],
                torch.tensor([[value]], device="cuda", dtype=torch.long),
            )
            torch.testing.assert_close(
                got[row : row + 1], want, atol=1e-4, rtol=1e-4
            )

    def test_mixed_valid_invalid_multi_index(self):
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 256
        x = torch.randn(4, V, device="cuda", dtype=torch.float32)
        idx = torch.tensor(
            [
                [0, 10, -1],  # trailing entry invalid
                [V, 5, 100],  # leading entry invalid
                [50, 60, 70],  # fully valid
                [-1, V + 1, -100],  # fully invalid
            ],
            device="cuda",
            dtype=torch.long,
        )
        got = selective_log_softmax(x, idx)
        assert got.shape == (4, 3)
        # The fully-valid row must agree with the reference exactly.
        ok_idx = torch.tensor([[50, 60, 70]], device="cuda", dtype=torch.long)
        want = _ref_selective_log_softmax(x[2:3], ok_idx)
        torch.testing.assert_close(got[2:3], want, atol=1e-4, rtol=1e-4)

    def test_oob_backward_no_crash(self):
        """Backward through OOB rows must complete and yield finite gradients."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        x = torch.randn(
            4, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        idx = torch.tensor(
            [[-1], [0], [V + 10], [V - 1]], device="cuda", dtype=torch.long
        )
        selective_log_softmax(x, idx).sum().backward()
        assert x.grad is not None
        assert torch.isfinite(x.grad).all()

    def test_oob_backward_valid_rows_correct(self):
        """Valid-row gradients stay correct even when sibling rows index OOB."""
        from axolotl.monkeypatch.trainer.utils import selective_log_softmax

        V = 128
        x = torch.randn(
            4, V, device="cuda", dtype=torch.float32, requires_grad=True
        )
        # Rows 0 and 2 are out of range; rows 1 and 3 are valid.
        idx = torch.tensor(
            [[-1], [42], [V + 5], [100]], device="cuda", dtype=torch.long
        )
        selective_log_softmax(x, idx).sum().backward()

        # Accumulate the reference gradient for the valid rows only.
        x_ref = x.detach().clone().requires_grad_(True)
        for row, value in zip([1, 3], [42, 100], strict=True):
            log_probs = F.log_softmax(x_ref[row : row + 1], dim=-1)
            log_probs[0, value].backward(retain_graph=True)

        for row in (1, 3):
            torch.testing.assert_close(
                x.grad[row], x_ref.grad[row], atol=1e-4, rtol=1e-4
            )
|
||||
Reference in New Issue
Block a user