"""Benchmark for ScatterMoE LoRA Triton kernels.

Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.

Usage:
    CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
    CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
    CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
|
|
|
|
import argparse
|
|
import gc
|
|
import time
|
|
from functools import partial
|
|
|
|
import torch
|
|
|
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
|
lora_ops,
|
|
ops as base_ops,
|
|
)
|
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
|
flatten_sort_count,
|
|
)
|
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
|
ScatterMoELoRA,
|
|
)
|
|
|
|
# Benchmark environment defaults.
DEVICE = "cuda"  # these kernels are CUDA-only
DTYPE = torch.bfloat16  # precision used for all weight/activation tensors
WARMUP = 5  # un-timed iterations before measurement
ITERS = 20  # timed iterations; the median is reported

# ─── Model configs ──────────────────────────────────────────────────────────

# Known MoE model shapes, keyed by short name.
# Values are (num_experts E, hidden size H, moe intermediate size I, top_k).
BUILTIN_CONFIGS = {
    "Qwen3.5-35B-A3B": (256, 2048, 512, 8),  # E, H, I, k
    "Qwen3-30B-A3B": (128, 2048, 768, 8),
    "OLMoE-1B-7B": (64, 2048, 1024, 8),
    "Mixtral-8x7B": (8, 4096, 14336, 2),
}
|
|
|
|
|
|
def _resolve_config(spec):
    """Resolve a model spec to ``(name, (E, H, I, k))``.

    Accepts builtin shorthand names (fuzzy substring match against
    ``BUILTIN_CONFIGS`` keys) or a HuggingFace model ID, in which case the
    shapes are read from the hub config.

    Args:
        spec: builtin name or HF model ID (e.g. "Qwen/Qwen3-30B-A3B").

    Returns:
        Tuple of (display name, (num_experts, hidden, intermediate, top_k)).

    Raises:
        ValueError: if the HF config exposes no expert-count attribute.
    """
    key = spec.lower().replace("/", "-")
    for name, cfg in BUILTIN_CONFIGS.items():
        if key in name.lower() or name.lower() in key:
            return name, cfg

    # Not a builtin: pull the shapes from the HF hub config. Imported lazily
    # so builtin-only runs don't need transformers at all.
    from transformers import AutoConfig

    hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
    # Multimodal configs nest the LM shapes under a text sub-config.
    if callable(getattr(hf_cfg, "get_text_config", None)):
        tc = hf_cfg.get_text_config()
        if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
            hf_cfg = tc
    hidden = hf_cfg.hidden_size
    inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
    # Expert-count attribute name varies across model families.
    experts = (
        getattr(hf_cfg, "num_experts", None)
        or getattr(hf_cfg, "num_local_experts", None)
        or getattr(hf_cfg, "n_routed_experts", None)
    )
    if experts is None:
        # Fail loudly here instead of letting a None expert count surface
        # later as an opaque error inside the benchmark loop.
        raise ValueError(
            f"Could not determine expert count for {spec!r}; "
            "config has none of num_experts/num_local_experts/n_routed_experts"
        )
    top_k = (
        getattr(hf_cfg, "num_experts_per_tok", None)
        or getattr(hf_cfg, "num_experts_per_token", None)
        or 2  # conservative default when the config omits top-k
    )
    name = spec.split("/")[-1]
    return name, (experts, hidden, inter, top_k)
|
|
|
|
|
|
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
|
|
|
|
|
def _clean():
    """Release cached memory so each benchmark run starts from a clean slate."""
    gc.collect()  # drop dead Python refs first so their CUDA blocks become free
    torch.cuda.empty_cache()  # return freed blocks from the caching allocator
    torch.cuda.synchronize()  # drain any pending kernels before measuring
|
|
|
|
|
|
def _bench(fn, warmup=WARMUP, iters=ITERS):
    """Time ``fn`` on the GPU and return the median wall-clock time in ms.

    Runs ``warmup`` un-timed calls first (kernel compilation, clock
    stabilization), then ``iters`` timed calls, synchronizing the device
    around each so the measurement covers the full kernel execution.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    samples = []
    for _ in range(iters):
        torch.cuda.synchronize()
        start = time.perf_counter()
        fn()
        torch.cuda.synchronize()
        samples.append((time.perf_counter() - start) * 1000)

    # Upper median of the samples — robust to occasional stragglers.
    return sorted(samples)[len(samples) // 2]
|
|
|
|
|
|
def _setup(num_experts, K, N, T, top_k, R):
    """Build deterministic inputs for one (K -> N) projection benchmark.

    Returns ``(x, W, lora_A, lora_B, sorted_expert_idxs,
    sorted_scattered_idxs, expert_offsets, grouped_x, dy)``.
    """
    torch.manual_seed(42)  # identical tensors across all kernel variants

    x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
    W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
    # LoRA factors are stacked per-expert along the rank dimension.
    lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
    lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01

    # Fake router: softmax over random logits, keep the top-k experts/token.
    router_logits = torch.randn(T, num_experts, device=DEVICE)
    top_idx = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)[1]
    sei, ssi, eo = flatten_sort_count(top_idx, num_experts)

    gx = base_ops.group(x, ssi, fan_out=top_k)
    dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
    return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
|
|
|
|
|
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
|
|
|
|
|
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
    """Run the LoRA-augmented forward scatter2scatter kernel once."""
    kwargs = {
        "X": x,
        "W": W,
        "sorted_expert_idxs": sei,
        "sorted_scattered_idxs": ssi,
        "k": top_k,
        "lora_A": lA,
        "lora_B": lB,
        "scaling": 2.0,
    }
    return lora_ops.scatter2scatter_lora(**kwargs)
|
|
|
|
|
|
def _call_base(x, W, sei, ssi, top_k):
    """Run the base (no-LoRA) scatter2scatter kernel once, for comparison."""
    kwargs = {
        "X": x,
        "W": W,
        "sorted_expert_idxs": sei,
        "sorted_scattered_idxs": ssi,
        "k": top_k,
    }
    return base_ops.scatter2scatter(**kwargs)
|
|
|
|
|
|
def _call_dx(dy, W, sei, ssi, lA, lB):
    """Run the LoRA backward-dX kernel once.

    NOTE(review): k=1 with dy_grouped=True — presumably because the grouped
    DY already carries the top-k fan-out; confirm against the kernel docs.
    """
    kwargs = {
        "DY": dy,
        "W": W,
        "sorted_expert_idxs": sei,
        "sorted_scattered_idxs": ssi,
        "k": 1,
        "lora_A": lA,
        "lora_B": lB,
        "scaling": 2.0,
        "dy_grouped": True,
        "dx_grouped": False,
    }
    return lora_ops.scatter2scatter_lora_dX(**kwargs)
|
|
|
|
|
|
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
    """Run the LoRA backward dA/dB (grouped weight-gradient) kernel once."""
    kwargs = {
        "DY": dy,
        "X": gx,
        "lora_A": lA,
        "lora_B": lB,
        "expert_offsets": eo,
        "E": num_experts,
        "scaling": 2.0,
    }
    return lora_ops.group_bwd_lora(**kwargs)
|
|
|
|
|
|
# ─── Main ────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def main():
    """Parse CLI args and run the benchmark grid (model × rank × projection)."""
    parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
    parser.add_argument(
        "--models",
        "-m",
        nargs="+",
        help="Model names or HF IDs (default: all builtins)",
    )
    parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
    parser.add_argument("--seq-len", "-T", type=int, default=2048)
    args = parser.parse_args()

    T = args.seq_len  # number of tokens per benchmark input
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"T={T}, ranks={args.ranks}\n")

    if args.models:
        configs = [_resolve_config(m) for m in args.models]
    else:
        configs = list(BUILTIN_CONFIGS.items())

    for model_name, (num_experts, hidden, inter, top_k) in configs:
        print(f"{'=' * 70}")
        print(f"  {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
        print(f"{'=' * 70}")

        for R in args.ranks:
            # Benchmark both MoE projections: gate_up (H -> 2I) and down (I -> H).
            for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
                _clean()
                x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
                    num_experts, K, N, T, top_k, R
                )

                # Forward with LoRA (auto-dispatched: fused or split).
                # Mirrors the dispatch rule inside lora_ops (private
                # thresholds) so the report can label which path ran.
                dispatch = (
                    "split"
                    if (
                        num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
                        and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
                    )
                    else "fused"
                )
                t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
                t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
                t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
                t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))

                total = t_fwd + t_dx + t_bwd
                # LoRA overhead relative to the base kernel (guard div-by-zero).
                overhead = t_fwd / t_base - 1 if t_base > 0 else 0

                print(
                    f"    R={R:>2} {proj:<8} "
                    f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
                    f"base={t_base:>6.2f}ms "
                    f"(+{overhead * 100:.0f}%) "
                    f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
                    f"total={total:>6.2f}ms"
                )

                # Full autograd fwd+bwd with memory measurement.
                x_ag = x.clone().requires_grad_(True)
                lA_ag = lA.clone().requires_grad_(True)
                lB_ag = lB.clone().requires_grad_(True)

                # Loop variables are bound as defaults so the closure captures
                # this iteration's tensors (avoids B023 late-binding capture).
                def _run_autograd(
                    _x=x_ag,
                    _W=W,
                    _k=top_k,
                    _sei=sei,
                    _ssi=ssi,
                    _eo=eo,
                    _lA=lA_ag,
                    _lB=lB_ag,
                ):
                    out = ScatterMoELoRA.apply(
                        _x,
                        _W,
                        _k,
                        _sei,
                        _ssi,
                        _eo,
                        _lA,
                        _lB,
                        2.0,
                        None,
                        None,
                        False,
                        False,
                        True,
                        False,
                    )
                    out.sum().backward()
                    # Drop grads so repeated timed calls don't accumulate.
                    _x.grad = None
                    _lA.grad = None
                    _lB.grad = None

                t_full = _bench(_run_autograd)

                _clean()
                torch.cuda.reset_peak_memory_stats()
                mem_before = torch.cuda.memory_allocated()
                _run_autograd()
                torch.cuda.synchronize()
                # Peak allocation above the steady-state baseline for one step.
                mem_peak = torch.cuda.max_memory_allocated() - mem_before

                print(
                    f"    full_fwd_bwd={t_full:>6.2f}ms "
                    f"peak_delta={mem_peak / 1e6:>6.1f}MB"
                )

        print()
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|