* use custom triton kernels for entropy from logits and selective softmax * PR comments fixes * fix out of bounds, include tests, include benchmarks * chore: lint
192 lines
5.5 KiB
Python
192 lines
5.5 KiB
Python
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
|
|
|
|
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
|
|
"""
|
|
|
|
import gc
|
|
import statistics
|
|
|
|
import torch
|
|
|
|
from axolotl.monkeypatch.trainer.utils import (
|
|
selective_log_softmax,
|
|
selective_log_softmax_original,
|
|
)
|
|
|
|
# Vocabulary dimension of the benchmarked logits (Qwen vocab).
V = 151_936
# Untimed calls issued before any measurement starts.
WARMUP = 5
# Number of timed calls per timing measurement.
BENCH_ITERS = 20
# Number of calls sampled per peak-memory measurement.
MEM_ITERS = 10
|
|
|
|
|
|
def _clean_gpu():
    """Reset the GPU memory state between measurements.

    Runs the Python garbage collector, then empties the CUDA cache, resets
    the peak and accumulated CUDA memory statistics, and finally
    synchronizes the device.
    """
    gc.collect()
    # Order matters: cached blocks are released before the counters are reset,
    # and the synchronize comes last.
    cuda_steps = (
        torch.cuda.empty_cache,
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
        torch.cuda.synchronize,
    )
    for step in cuda_steps:
        step()
|
|
|
|
|
|
def profile_time(fn, args, n_iters=BENCH_ITERS):
    """Time repeated calls of ``fn(*args)`` using CUDA events.

    Args:
        fn: Callable to benchmark.
        args: Positional arguments unpacked into every call.
        n_iters: Number of timed calls, made after ``WARMUP`` untimed calls.

    Returns:
        A list with one ``Event.elapsed_time`` reading per timed call.
    """
    for _ in range(WARMUP):
        fn(*args)
    torch.cuda.synchronize()

    def _timed_call():
        # Bracket a single call with a pair of timing events.
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args)
        stop.record()
        # Synchronize before reading the elapsed time between the two events.
        torch.cuda.synchronize()
        return start.elapsed_time(stop)

    return [_timed_call() for _ in range(n_iters)]
|
|
|
|
|
|
def profile_memory(fn, args, n_iters=MEM_ITERS):
    """Measure how far ``fn(*args)`` raises peak CUDA memory above a baseline.

    Args:
        fn: Callable to profile.
        args: Positional arguments unpacked into every call.
        n_iters: Number of measured calls, made after ``WARMUP`` unmeasured calls.

    Returns:
        A list with one entry per measured call: the increase of
        ``torch.cuda.max_memory_allocated()`` over its value just before the
        call, divided by 1e6.
    """
    for _ in range(WARMUP):
        result = fn(*args)
        del result
    torch.cuda.synchronize()

    overheads_mb = []
    for _ in range(n_iters):
        # Reset the peak counter so the baseline is read from a fresh state.
        _clean_gpu()
        baseline = torch.cuda.max_memory_allocated()
        result = fn(*args)
        torch.cuda.synchronize()
        overhead = torch.cuda.max_memory_allocated() - baseline
        # Drop the output before the next iteration takes its baseline.
        del result
        overheads_mb.append(overhead / 1e6)
    return overheads_mb
|
|
|
|
|
|
def fmt(values, unit=""):
    """Summarize a sequence of measurements as one display string.

    Args:
        values: Non-empty sequence of numbers.
        unit: Label printed after the standard deviation.

    Returns:
        A string of the form ``mean ± std unit [min=..., max=...]``; the
        standard deviation is reported as 0.0 when there is a single value.
    """
    avg = statistics.mean(values)
    if len(values) > 1:
        spread = statistics.stdev(values)
    else:
        # The sample standard deviation is undefined for a single observation.
        spread = 0.0
    lo, hi = min(values), max(values)
    return f"{avg:8.2f} ± {spread:5.2f} {unit} [min={lo:.2f}, max={hi:.2f}]"
|
|
|
|
|
|
def benchmark_forward():
    """Benchmark the forward pass of both selective_log_softmax variants.

    For each (batch, sequence length) configuration, builds random bfloat16
    logits of shape (batch, seq_len, V) and random target indices on the GPU,
    then prints timing and peak-memory statistics for the original and the
    Triton implementation. Configurations whose logits would exceed 28 GB
    (at 2 bytes per element) are skipped.
    """
    banner = "=" * 60
    rule = "─" * 60

    print(banner)
    print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
    print(banner)

    shapes = [
        (1, 2048),
        (1, 8192),
        (4, 4096),
        (8, 2048),
        (16, 2048),
        (16, 4096),
    ]

    for batch, seq_len in shapes:
        n_rows = batch * seq_len
        # Size of the logits tensor alone, at 2 bytes per element.
        logits_gb = n_rows * V * 2 / 1e9
        if logits_gb > 28:
            print(f"\n skip B={batch}, L={seq_len} ({logits_gb:.1f} GB)")
            continue

        print(f"\n{rule}")
        print(f"B={batch:2d}, L={seq_len:5d} ({n_rows:6d} rows, logits {logits_gb:.2f} GB)")
        print(f"{rule}")

        torch.manual_seed(42)
        logits = torch.randn(batch, seq_len, V, device="cuda", dtype=torch.bfloat16)
        index = torch.randint(0, V, (batch, seq_len), device="cuda")

        # Timing: original first, then Triton.
        times_orig = profile_time(selective_log_softmax_original, (logits, index))
        times_triton = profile_time(selective_log_softmax, (logits, index))
        speedup = statistics.mean(times_orig) / statistics.mean(times_triton)

        print(" TIME (ms):")
        print(f" original: {fmt(times_orig, 'ms')}")
        print(f" triton: {fmt(times_triton, 'ms')}")
        print(f" speedup: {speedup:.2f}x")

        # Peak-memory overhead, in the same order.
        mem_orig = profile_memory(selective_log_softmax_original, (logits, index))
        mem_triton = profile_memory(selective_log_softmax, (logits, index))
        saved_mb = statistics.mean(mem_orig) - statistics.mean(mem_triton)

        print(" MEMORY (peak overhead):")
        print(f" original: {fmt(mem_orig, 'MB')}")
        print(f" triton: {fmt(mem_triton, 'MB')}")
        print(f" saved: {saved_mb:.1f} MB")

        # Release the inputs before the next configuration allocates its own.
        del logits, index
        _clean_gpu()
|
|
|
|
|
|
def benchmark_backward():
    """Benchmark forward+backward of both selective_log_softmax variants.

    For each (batch, sequence length) configuration, builds random bfloat16
    logits of shape (batch, seq_len, V) that require grad (one independent
    copy per implementation) plus random target indices, then prints timing
    and peak-memory statistics for a forward pass followed by
    ``out.sum().backward()``. Configurations whose logits would exceed 20 GB
    (at 2 bytes per element) are skipped to leave room for gradients.
    """
    print("\n" + "=" * 60)
    print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
    print("=" * 60)

    configs = [
        (1, 2048),
        (1, 8192),
        (4, 4096),
        (8, 2048),
        (16, 2048),
        (16, 4096),
    ]

    def _fwd_bwd(impl, logits, index):
        # Run one forward+backward step of `impl`, leaving no gradient behind.
        logits.grad = None
        out = impl(logits, index)
        out.sum().backward()
        # Release the gradient right away. If it stayed alive until the next
        # call, it would be counted in profile_memory's baseline (taken before
        # the call) and then freed inside the call, understating the reported
        # peak overhead by one full gradient tensor. It would also keep both
        # implementations' gradients resident for the whole configuration.
        logits.grad = None

    def fwd_bwd_original(logits, index):
        _fwd_bwd(selective_log_softmax_original, logits, index)

    def fwd_bwd_triton(logits, index):
        _fwd_bwd(selective_log_softmax, logits, index)

    for B, L in configs:
        # Size of one logits tensor alone, at 2 bytes per element.
        mem_gb = B * L * V * 2 / 1e9
        if mem_gb > 20:
            print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
            continue

        N = B * L
        print(f"\n{'─' * 60}")
        print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
        print(f"{'─' * 60}")

        torch.manual_seed(42)
        logits_orig = torch.randn(
            B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        # Separate leaf tensor with identical values for the Triton path.
        logits_tri = logits_orig.detach().clone().requires_grad_(True)
        index = torch.randint(0, V, (B, L), device="cuda")

        t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
        t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
        orig_mean = statistics.mean(t_orig)
        triton_mean = statistics.mean(t_triton)

        print(" FWD+BWD TIME (ms):")
        print(f" original: {fmt(t_orig, 'ms')}")
        print(f" triton: {fmt(t_triton, 'ms')}")
        print(f" speedup: {orig_mean / triton_mean:.2f}x")

        m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
        m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
        orig_peak = statistics.mean(m_orig)
        triton_peak = statistics.mean(m_triton)

        print(" FWD+BWD MEMORY (peak overhead):")
        print(f" original: {fmt(m_orig, 'MB')}")
        print(f" triton: {fmt(m_triton, 'MB')}")
        print(f" saved: {orig_peak - triton_peak:.1f} MB")

        # Release the inputs before the next configuration allocates its own.
        del logits_orig, logits_tri, index
        _clean_gpu()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the forward-only suite first, then the forward+backward suite.
    for run_benchmark in (benchmark_forward, benchmark_backward):
        run_benchmark()
|