add correctness unit tests and benchmarks for scattermoe + lora

This commit is contained in:
Wing Lian
2026-03-19 06:40:01 +00:00
parent 07ff389be8
commit 66fea258c7
2 changed files with 504 additions and 0 deletions

View File

@@ -0,0 +1,193 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import statistics
import time
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
ops as base_ops,
lora_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
# Try HuggingFace AutoConfig
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
H = hf_cfg.hidden_size
I = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
E = (getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None))
k = (getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None) or 2)
name = spec.split("/")[-1]
return name, (E, H, I, k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
return statistics.median(times)
def _setup(E, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, E, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, E)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument("--models", "-m", nargs="+",
help="Model names or HF IDs (default: all builtins)")
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
configs = [(n, c) for n, c in configs]
for model_name, (E, H, I, k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={E}, H={H}, I={I}, k={k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", H, 2 * I), ("down", I, H)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
# Forward with LoRA
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
))
# Forward without LoRA (base)
t_base = _bench(lambda: base_ops.scatter2scatter(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
))
# Backward dX
t_dx = _bench(lambda: lora_ops.scatter2scatter_lora_dX(
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=1, lora_A=lA, lora_B=lB, scaling=2.0,
dy_grouped=True, dx_grouped=False,
))
# Backward dA/dB
t_bwd = _bench(lambda: lora_ops.group_bwd_lora(
DY=dy, X=gx, lora_A=lA, lora_B=lB,
expert_offsets=eo, E=E, scaling=2.0,
))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms "
f"(+{overhead*100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms")
# Full autograd fwd+bwd
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd():
out = ScatterMoELoRA.apply(
x_ag, W, k, sei, ssi, eo,
lA_ag, lB_ag, 2.0,
None, None, False, False, True, False,
)
out.sum().backward()
x_ag.grad = None
lA_ag.grad = None
lB_ag.grad = None
t_full = _bench(_run_autograd)
print(f" full_fwd_bwd={t_full:>6.2f}ms")
print()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Unit tests for ScatterMoE LoRA Triton kernels.
Tests correctness of:
- scatter2scatter_lora (forward)
- scatter2scatter_lora_dX (backward input gradient)
- group_bwd_lora (backward LoRA weight gradients via split dA/dB)
- ScatterMoELoRA autograd function (full forward + backward)
Each kernel is tested against a pure PyTorch per-expert-loop reference
implementation at multiple model shapes and LoRA ranks.
"""
import pytest
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
ops as base_ops,
lora_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
def _requires_cuda():
return pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
)
pytestmark = _requires_cuda()
# ─── Helpers ─────────────────────────────────────────────────────────────────
def _setup(E, K, N, T, top_k, R, seed=42):
"""Create synthetic expert weights, LoRA, routing, and grouped inputs."""
torch.manual_seed(seed)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, E, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, E)
return x, W, lora_A, lora_B, sei, ssi, eo
def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):
"""Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T."""
grouped_x = base_ops.group(x, ssi, fan_out=k)
M, N = grouped_x.size(0), W.size(2)
R = lora_A.size(0) // E
out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
for e in range(E):
s = eo[e - 1].item() if e > 0 else 0
end = eo[e].item()
if s == end:
continue
xe = grouped_x[s:end].float()
we = W[e].float()
ae = lora_A[e * R : (e + 1) * R].float()
be = lora_B[:, e * R : (e + 1) * R].float()
out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE)
result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
result[ssi] = out
return result
def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):
"""Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A."""
M, K = dy_grouped.size(0), W.size(1)
R = lora_A.size(0) // E
out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
for e in range(E):
s = eo[e - 1].item() if e > 0 else 0
end = eo[e].item()
if s == end:
continue
dye = dy_grouped[s:end].float()
we = W[e].float()
ae = lora_A[e * R : (e + 1) * R].float()
be = lora_B[:, e * R : (e + 1) * R].float()
out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE)
result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
result[ssi] = out
return result
def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):
"""Per-expert loop reference: dA, dB for LoRA weight gradients."""
R = lora_A.size(0) // E
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
for e in range(E):
s = eo[e - 1].item() if e > 0 else 0
end = eo[e].item()
if s == end:
continue
xe = grouped_x[s:end].float()
dye = dy[s:end].float()
ae = lora_A[e * R : (e + 1) * R].float()
be = lora_B[:, e * R : (e + 1) * R].float()
dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE)
dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)
return dA, dB
# ─── Model shape configs ────────────────────────────────────────────────────
# (E, K, N, T, top_k, R, description)
CONFIGS_SMALL = [
(32, 128, 64, 64, 2, 4, "tiny"),
(64, 256, 128, 128, 4, 8, "small"),
]
CONFIGS_REAL = [
(256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"),
(256, 512, 2048, 2048, 8, 16, "qwen35_down"),
(64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"),
(128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"),
]
SCALING = 2.0
# ─── Forward tests ──────────────────────────────────────────────────────────
class TestScatter2ScatterLoRAForward:
"""Test scatter2scatter_lora forward kernel vs reference."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
def config(self, request):
return request.param
def test_matches_reference(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
kernel_out = lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
)
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
err = (kernel_out.float() - ref_out.float()).abs().max().item()
assert err < 1.0, f"[{desc}] fwd max_err={err}"
def test_output_shape(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out = lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
)
assert out.shape == (T * k, N)
assert out.dtype == DTYPE
# ─── Backward dX tests ──────────────────────────────────────────────────────
class TestScatter2ScatterLoRADX:
"""Test scatter2scatter_lora_dX backward kernel vs reference."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
def config(self, request):
return request.param
def test_matches_reference(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
gx = base_ops.group(x, ssi, fan_out=k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kernel_dx = lora_ops.scatter2scatter_lora_dX(
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=1, lora_A=lA, lora_B=lB, scaling=SCALING,
dy_grouped=True, dx_grouped=False,
)
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
err = (kernel_dx.float() - ref_dx.float()).abs().max().item()
assert err < 1.0, f"[{desc}] dX max_err={err}"
# ─── Backward LoRA gradient tests ───────────────────────────────────────────
class TestGroupBwdLoRA:
"""Test group_bwd_lora (split dA/dB kernel) vs reference."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
def config(self, request):
return request.param
def test_matches_reference(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
gx = base_ops.group(x, ssi, fan_out=k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kern_dA, kern_dB = lora_ops.group_bwd_lora(
DY=dy, X=gx, lora_A=lA, lora_B=lB,
expert_offsets=eo, E=E, scaling=SCALING,
)
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
# Use norm-relative error: bf16 accumulation order differs between
# kernel (tiled + different reduction order) and reference (per-expert
# fp32 loop), so max absolute error can be large on individual elements
# while the overall tensor is correct.
dA_norm_err = (
(kern_dA.float() - ref_dA.float()).norm()
/ (ref_dA.float().norm() + 1e-6)
).item()
dB_norm_err = (
(kern_dB.float() - ref_dB.float()).norm()
/ (ref_dB.float().norm() + 1e-6)
).item()
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
def test_zero_expert_tokens(self):
"""Experts with zero routed tokens produce zero gradients."""
E, K, N, R = 8, 64, 32, 4
torch.manual_seed(42)
# Route all tokens to expert 0 only
T, k = 16, 1
top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)
sei, ssi, eo = flatten_sort_count(top_idx, E)
gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)
lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
dA, dB = lora_ops.group_bwd_lora(
DY=dy, X=gx, lora_A=lA, lora_B=lB,
expert_offsets=eo, E=E, scaling=2.0,
)
# Experts 1..7 should have zero gradients
for e in range(1, E):
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dB not zero"
# ─── Full autograd tests ────────────────────────────────────────────────────
class TestScatterMoELoRAAutograd:
"""Test full forward + backward through ScatterMoELoRA autograd function."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])
def config(self, request):
return request.param
def test_gradients_exist_and_finite(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
x = x.requires_grad_(True)
lA = lA.requires_grad_(True)
lB = lB.requires_grad_(True)
out = ScatterMoELoRA.apply(
x, W, k, sei, ssi, eo,
lA, lB, SCALING,
None, None, False, False, True, False,
)
out.sum().backward()
assert x.grad is not None, f"[{desc}] x.grad is None"
assert lA.grad is not None, f"[{desc}] lA.grad is None"
assert lB.grad is not None, f"[{desc}] lB.grad is None"
assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite"
assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite"
assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite"
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
def test_scaling_zero_gives_base_only(self):
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out_lora = ScatterMoELoRA.apply(
x, W, k, sei, ssi, eo,
lA, lB, 0.0,
None, None, False, False, True, False,
)
out_base = base_ops.scatter2scatter(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
)
err = (out_lora.float() - out_base.float()).abs().max().item()
assert err < 0.01, f"scaling=0 should match base: err={err}"