From 66fea258c7f8a4d9bb6c8bb348d07f4234f41a6d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 19 Mar 2026 06:40:01 +0000 Subject: [PATCH] add correctness unit tests and benchmarks for scattermoe + lora --- benchmarks/bench_scattermoe_lora.py | 193 +++++++++++ .../test_scattermoe_lora_kernels.py | 311 ++++++++++++++++++ 2 files changed, 504 insertions(+) create mode 100644 benchmarks/bench_scattermoe_lora.py create mode 100644 tests/integrations/test_scattermoe_lora_kernels.py diff --git a/benchmarks/bench_scattermoe_lora.py b/benchmarks/bench_scattermoe_lora.py new file mode 100644 index 000000000..3b995c1ff --- /dev/null +++ b/benchmarks/bench_scattermoe_lora.py @@ -0,0 +1,193 @@ +"""Benchmark for ScatterMoE LoRA Triton kernels. + +Measures forward, backward dX, and backward dA/dB kernels at common MoE +model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter, +and full fwd+bwd autograd throughput. + +Usage: + CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py + CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64 + CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B +""" + +import argparse +import gc +import statistics +import time + +import torch + +from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import ( + ops as base_ops, + lora_ops, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import ( + ScatterMoELoRA, +) + +DEVICE = "cuda" +DTYPE = torch.bfloat16 +WARMUP = 5 +ITERS = 20 + +# ─── Model configs ────────────────────────────────────────────────────────── + +BUILTIN_CONFIGS = { + "Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k + "Qwen3-30B-A3B": (128, 2048, 768, 8), + "OLMoE-1B-7B": (64, 2048, 1024, 8), + "Mixtral-8x7B": (8, 4096, 14336, 2), +} + + +def _resolve_config(spec): + """Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs.""" + key = spec.lower().replace("/", "-") + for name, cfg in BUILTIN_CONFIGS.items(): + if key in name.lower() or name.lower() in key: + return name, cfg + + # Try HuggingFace AutoConfig + from transformers import AutoConfig + hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True) + if callable(getattr(hf_cfg, "get_text_config", None)): + tc = hf_cfg.get_text_config() + if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type: + hf_cfg = tc + H = hf_cfg.hidden_size + I = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size + E = (getattr(hf_cfg, "num_experts", None) + or getattr(hf_cfg, "num_local_experts", None) + or getattr(hf_cfg, "n_routed_experts", None)) + k = (getattr(hf_cfg, "num_experts_per_tok", None) + or getattr(hf_cfg, "num_experts_per_token", None) or 2) + name = spec.split("/")[-1] + return name, (E, H, I, k) + + +# ─── Benchmark helpers ────────────────────────────────────────────────────── + +def _clean(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def _bench(fn, warmup=WARMUP, iters=ITERS): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + times = [] + for _ in range(iters): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + torch.cuda.synchronize() + times.append((time.perf_counter() - t0) * 1000) + return statistics.median(times) + + +def _setup(E, K, N, T, top_k, R): + torch.manual_seed(42) + x = torch.randn(T, K, device=DEVICE, dtype=DTYPE) + W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02 + lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01 + lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01 + logits = torch.randn(T, E, device=DEVICE) + _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1) + sei, ssi, eo = flatten_sort_count(top_idx, E) + gx = base_ops.group(x, ssi, fan_out=top_k) + dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) + return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy + + +# ─── Main ──────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark") + parser.add_argument("--models", "-m", nargs="+", + help="Model names or HF IDs (default: all builtins)") + parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64]) + parser.add_argument("--seq-len", "-T", type=int, default=2048) + args = parser.parse_args() + + T = args.seq_len + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"T={T}, ranks={args.ranks}\n") + + if args.models: + configs = [_resolve_config(m) for m in args.models] + else: + configs = list(BUILTIN_CONFIGS.items()) + configs = [(n, c) for n, c in configs] + + for model_name, (E, H, I, k) in configs: + print(f"{'=' * 70}") + print(f" {model_name}: E={E}, H={H}, I={I}, k={k}") + print(f"{'=' * 70}") + + for R in args.ranks: + for proj, K, N in [("gate_up", H, 2 * I), ("down", I, H)]: + _clean() + x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R) + + # Forward with LoRA + t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora( + X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, + k=k, lora_A=lA, lora_B=lB, scaling=2.0, + )) + + # Forward without LoRA (base) + t_base = _bench(lambda: base_ops.scatter2scatter( + X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k, + )) + + # Backward dX + t_dx = _bench(lambda: lora_ops.scatter2scatter_lora_dX( + DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, + k=1, lora_A=lA, lora_B=lB, scaling=2.0, + dy_grouped=True, dx_grouped=False, + )) + + # Backward dA/dB + t_bwd = _bench(lambda: lora_ops.group_bwd_lora( + DY=dy, X=gx, lora_A=lA, lora_B=lB, + expert_offsets=eo, E=E, scaling=2.0, + )) + + total = t_fwd + t_dx + t_bwd + overhead = t_fwd / t_base - 1 if t_base > 0 else 0 + + print(f" R={R:>2} {proj:<8} " + f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms " + f"(+{overhead*100:.0f}%) " + f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms " + f"total={total:>6.2f}ms") + + # Full autograd fwd+bwd + x_ag = x.clone().requires_grad_(True) + lA_ag = lA.clone().requires_grad_(True) + lB_ag = lB.clone().requires_grad_(True) + + def _run_autograd(): + out = ScatterMoELoRA.apply( + x_ag, W, k, sei, ssi, eo, + lA_ag, lB_ag, 2.0, + None, None, False, False, True, False, + ) + out.sum().backward() + x_ag.grad = None + lA_ag.grad = None + lB_ag.grad = None + + t_full = _bench(_run_autograd) + print(f" full_fwd_bwd={t_full:>6.2f}ms") + + print() + + +if __name__ == "__main__": + main() diff --git a/tests/integrations/test_scattermoe_lora_kernels.py b/tests/integrations/test_scattermoe_lora_kernels.py new file mode 100644 index 000000000..fa6dc72f5 --- /dev/null +++ b/tests/integrations/test_scattermoe_lora_kernels.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Unit tests for ScatterMoE LoRA Triton kernels. + +Tests correctness of: + - scatter2scatter_lora (forward) + - scatter2scatter_lora_dX (backward input gradient) + - group_bwd_lora (backward LoRA weight gradients via split dA/dB) + - ScatterMoELoRA autograd function (full forward + backward) + +Each kernel is tested against a pure PyTorch per-expert-loop reference +implementation at multiple model shapes and LoRA ranks. +""" + +import pytest +import torch + +from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import ( + ops as base_ops, + lora_ops, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import ( + ScatterMoELoRA, +) + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +def _requires_cuda(): + return pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ) + + +pytestmark = _requires_cuda() + + +# ─── Helpers ───────────────────────────────────────────────────────────────── + + +def _setup(E, K, N, T, top_k, R, seed=42): + """Create synthetic expert weights, LoRA, routing, and grouped inputs.""" + torch.manual_seed(seed) + x = torch.randn(T, K, device=DEVICE, dtype=DTYPE) + W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02 + lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01 + lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01 + logits = torch.randn(T, E, device=DEVICE) + _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1) + sei, ssi, eo = flatten_sort_count(top_idx, E) + return x, W, lora_A, lora_B, sei, ssi, eo + + +def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E): + """Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T.""" + grouped_x = base_ops.group(x, ssi, fan_out=k) + M, N = grouped_x.size(0), W.size(2) + R = lora_A.size(0) // E + out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE) + for e in range(E): + s = eo[e - 1].item() if e > 0 else 0 + end = eo[e].item() + if s == end: + continue + xe = grouped_x[s:end].float() + we = W[e].float() + ae = lora_A[e * R : (e + 1) * R].float() + be = lora_B[:, e * R : (e + 1) * R].float() + out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE) + result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE) + result[ssi] = out + return result + + +def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E): + """Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A.""" + M, K = dy_grouped.size(0), W.size(1) + R = lora_A.size(0) // E + out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE) + for e in range(E): + s = eo[e - 1].item() if e > 0 else 0 + end = eo[e].item() + if s == end: + continue + dye = dy_grouped[s:end].float() + we = W[e].float() + ae = lora_A[e * R : (e + 1) * R].float() + be = lora_B[:, e * R : (e + 1) * R].float() + out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE) + result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE) + result[ssi] = out + return result + + +def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling): + """Per-expert loop reference: dA, dB for LoRA weight gradients.""" + R = lora_A.size(0) // E + dA = torch.zeros_like(lora_A) + dB = torch.zeros_like(lora_B) + for e in range(E): + s = eo[e - 1].item() if e > 0 else 0 + end = eo[e].item() + if s == end: + continue + xe = grouped_x[s:end].float() + dye = dy[s:end].float() + ae = lora_A[e * R : (e + 1) * R].float() + be = lora_B[:, e * R : (e + 1) * R].float() + dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE) + dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE) + return dA, dB + + +# ─── Model shape configs ──────────────────────────────────────────────────── + +# (E, K, N, T, top_k, R, description) +CONFIGS_SMALL = [ + (32, 128, 64, 64, 2, 4, "tiny"), + (64, 256, 128, 128, 4, 8, "small"), +] + +CONFIGS_REAL = [ + (256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"), + (256, 512, 2048, 2048, 8, 16, "qwen35_down"), + (64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"), + (128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"), +] + +SCALING = 2.0 + + +# ─── Forward tests ────────────────────────────────────────────────────────── + + +class TestScatter2ScatterLoRAForward: + """Test scatter2scatter_lora forward kernel vs reference.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL) + def config(self, request): + return request.param + + def test_matches_reference(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + kernel_out = lora_ops.scatter2scatter_lora( + X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, + k=k, lora_A=lA, lora_B=lB, scaling=SCALING, + ) + ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E) + + err = (kernel_out.float() - ref_out.float()).abs().max().item() + assert err < 1.0, f"[{desc}] fwd max_err={err}" + + def test_output_shape(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + out = lora_ops.scatter2scatter_lora( + X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, + k=k, lora_A=lA, lora_B=lB, scaling=SCALING, + ) + assert out.shape == (T * k, N) + assert out.dtype == DTYPE + + +# ─── Backward dX tests ────────────────────────────────────────────────────── + + +class TestScatter2ScatterLoRADX: + """Test scatter2scatter_lora_dX backward kernel vs reference.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL) + def config(self, request): + return request.param + + def test_matches_reference(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + gx = base_ops.group(x, ssi, fan_out=k) + dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) + + kernel_dx = lora_ops.scatter2scatter_lora_dX( + DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, + k=1, lora_A=lA, lora_B=lB, scaling=SCALING, + dy_grouped=True, dx_grouped=False, + ) + ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E) + + err = (kernel_dx.float() - ref_dx.float()).abs().max().item() + assert err < 1.0, f"[{desc}] dX max_err={err}" + + +# ─── Backward LoRA gradient tests ─────────────────────────────────────────── + + +class TestGroupBwdLoRA: + """Test group_bwd_lora (split dA/dB kernel) vs reference.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL) + def config(self, request): + return request.param + + def test_matches_reference(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + gx = base_ops.group(x, ssi, fan_out=k) + dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) + + kern_dA, kern_dB = lora_ops.group_bwd_lora( + DY=dy, X=gx, lora_A=lA, lora_B=lB, + expert_offsets=eo, E=E, scaling=SCALING, + ) + ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING) + + # Use norm-relative error: bf16 accumulation order differs between + # kernel (tiled + different reduction order) and reference (per-expert + # fp32 loop), so max absolute error can be large on individual elements + # while the overall tensor is correct. + dA_norm_err = ( + (kern_dA.float() - ref_dA.float()).norm() + / (ref_dA.float().norm() + 1e-6) + ).item() + dB_norm_err = ( + (kern_dB.float() - ref_dB.float()).norm() + / (ref_dB.float().norm() + 1e-6) + ).item() + assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}" + assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}" + + def test_zero_expert_tokens(self): + """Experts with zero routed tokens produce zero gradients.""" + E, K, N, R = 8, 64, 32, 4 + torch.manual_seed(42) + # Route all tokens to expert 0 only + T, k = 16, 1 + top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE) + sei, ssi, eo = flatten_sort_count(top_idx, E) + gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE) + dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE) + lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) + lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) + + dA, dB = lora_ops.group_bwd_lora( + DY=dy, X=gx, lora_A=lA, lora_B=lB, + expert_offsets=eo, E=E, scaling=2.0, + ) + + # Experts 1..7 should have zero gradients + for e in range(1, E): + assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero" + assert dB[:, e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dB not zero" + + +# ─── Full autograd tests ──────────────────────────────────────────────────── + + +class TestScatterMoELoRAAutograd: + """Test full forward + backward through ScatterMoELoRA autograd function.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2]) + def config(self, request): + return request.param + + def test_gradients_exist_and_finite(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + x = x.requires_grad_(True) + lA = lA.requires_grad_(True) + lB = lB.requires_grad_(True) + + out = ScatterMoELoRA.apply( + x, W, k, sei, ssi, eo, + lA, lB, SCALING, + None, None, False, False, True, False, + ) + out.sum().backward() + + assert x.grad is not None, f"[{desc}] x.grad is None" + assert lA.grad is not None, f"[{desc}] lA.grad is None" + assert lB.grad is not None, f"[{desc}] lB.grad is None" + assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite" + assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite" + assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite" + assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero" + assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero" + + + def test_scaling_zero_gives_base_only(self): + """With scaling=0.0, LoRA contribution vanishes. Output = X@W.""" + E, K, N, T, k, R = 16, 64, 32, 32, 2, 4 + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + out_lora = ScatterMoELoRA.apply( + x, W, k, sei, ssi, eo, + lA, lB, 0.0, + None, None, False, False, True, False, + ) + out_base = base_ops.scatter2scatter( + X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k, + ) + err = (out_lora.float() - out_base.float()).abs().max().item() + assert err < 0.01, f"scaling=0 should match base: err={err}"