#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block.

Recovered from the patch "bench sweep" (commit
fb11f696e986747d96156406ba4cea8223c621e3, Dan Saunders, 2025-09-19), which
adds this file as ``scripts/bench_moe_sweep.py``.

For every combination of batch size, sequence length, hidden size,
intermediate size, expert count and top-k, the script times the stock
``Qwen2MoeSparseMoeBlock`` forward pass against
``torch_grouped.moe_ffn_forward_grouped`` on identically initialised weights,
prints one tab-separated row per combination and optionally writes all rows
to a CSV file.
"""

import argparse
import csv
import itertools
import sys
import time
import weakref
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Callable, Dict, List

import torch

try:
    from axolotl.kernels.moe import torch_grouped as tg
except Exception:  # pragma: no cover
    # Deliberately broad: the grouped kernels are optional here. ``main``
    # checks for ``tg is None`` and aborts the sweep with a message.
    tg = None


# Format specs for the float columns of the CSV output. Columns that are not
# listed (the integer sweep parameters and ``dtype``) are written unformatted.
_CSV_FLOAT_SPECS: Dict[str, str] = {
    "naive_ms": ".4f",
    "grouped_ms": ".4f",
    "speedup": ".4f",
    "naive_tflops": ".4f",
    "grouped_tflops": ".4f",
    "max_abs": ".6e",
    "mean_abs": ".6e",
    "rel_l2": ".6e",
}


def _parse_list(arg: str) -> List[int]:
    """Parse a comma-separated string such as ``"4,8,16"`` into ints.

    Empty and whitespace-only items (``"4,,8"``, ``"4, ,8"``, a trailing
    comma) are skipped.

    Raises:
        ValueError: if a non-blank item is not an integer literal.
    """
    return [int(v) for v in arg.split(",") if v.strip()]


def _bench(
    run: Callable[[], object], *, iters: int, warmup: int, device: torch.device
) -> float:
    """Return the mean wall-clock time of ``run()`` in milliseconds.

    ``run`` is called ``warmup`` times untimed, then ``iters`` times timed.
    On a CUDA device the stream is synchronised before starting and after
    finishing each timed call so that the measured interval covers the
    queued GPU work rather than only the kernel launches.

    Raises:
        ValueError: if ``iters`` is less than 1 (there would be nothing to
            average).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    on_cuda = device.type == "cuda"

    for _ in range(warmup):
        run()
    if on_cuda:
        torch.cuda.synchronize()

    times: List[float] = []
    for _ in range(iters):
        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        run()
        if on_cuda:
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000.0)
    return sum(times) / len(times)


def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    """Return ``6 * tokens * top_k * hidden * inter`` as a float.

    NOTE(review): presumably three hidden-by-inter projections per routed
    expert at 2 FLOPs per multiply-add; the estimate scales only with
    ``top_k`` and has no separate router or shared-expert term -- confirm
    that this is the intended accounting.
    """
    return 6.0 * tokens * top_k * hidden * inter


def _load_block(
    hidden: int,
    inter: int,
    experts: int,
    top_k: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build two ``Qwen2MoeSparseMoeBlock`` instances with identical weights.

    If a ``transformers/src`` directory exists two levels above the directory
    holding this script, it is appended to ``sys.path`` before importing
    ``transformers``. Because it is appended rather than inserted, an
    already importable ``transformers`` still takes precedence.

    Returns:
        ``(block, block_grouped)``: both on ``device`` with ``dtype``;
        ``block_grouped`` is loaded from ``block``'s state dict.
    """
    ancestors = Path(__file__).resolve().parents
    # Guard against a script location that is fewer than three levels deep,
    # where ``parents[2]`` would raise IndexError.
    if len(ancestors) > 2:
        transformers_src = ancestors[2] / "transformers" / "src"
        if transformers_src.exists() and str(transformers_src) not in sys.path:
            sys.path.append(str(transformers_src))

    from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

    cfg = Qwen2MoeConfig(
        hidden_size=hidden,
        moe_intermediate_size=inter,
        shared_expert_intermediate_size=inter,
        num_experts=experts,
        num_experts_per_tok=top_k,
        norm_topk_prob=True,
        qkv_bias=True,
    )

    block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped.load_state_dict(block.state_dict())
    return block, block_grouped


@dataclass
class Result:
    """Measurements for one point of the sweep."""

    # Sweep parameters.
    bsz: int
    seq: int
    hidden: int
    inter: int
    experts: int
    top_k: int
    dtype: str  # the --dtype choice: "bf16", "fp16" or "fp32"
    # Mean forward time per call in milliseconds.
    naive_ms: float
    grouped_ms: float
    speedup: float  # naive_ms / grouped_ms
    # Estimated throughput: _estimate_flops(...) / seconds / 1e12.
    naive_tflops: float
    grouped_tflops: float
    # Difference between the naive and grouped outputs, computed in float32.
    max_abs: float
    mean_abs: float
    rel_l2: float  # sqrt(sum(diff^2) / (sum(naive^2) + 1e-12))


def _run_case(
    bsz: int,
    seq: int,
    hidden: int,
    inter: int,
    experts: int,
    top_k: int,
    *,
    dtype_name: str,
    dtype: torch.dtype,
    device: torch.device,
    iters: int,
    warmup: int,
) -> Result:
    """Benchmark one sweep point and compare the two implementations."""
    # Re-seed for every case so weights and inputs are reproducible.
    torch.manual_seed(0)
    if device.type == "cuda":
        torch.cuda.manual_seed(0)

    block_naive, block_grouped = _load_block(
        hidden,
        inter,
        experts,
        top_k,
        device=device,
        dtype=dtype,
    )
    x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)

    def run_naive():
        y, _ = block_naive(x)
        return y

    def run_grouped():
        # Attach a weak reference to the parent block on the experts
        # container before every call. NOTE(review): presumably read by the
        # grouped kernel; kept inside the timed call as in the original.
        block_grouped.experts._ax_parent_block_ref = weakref.ref(block_grouped)  # type: ignore
        y, _ = tg.moe_ffn_forward_grouped(
            x,
            block_grouped.gate,
            block_grouped.experts,
            block_grouped.top_k,
        )
        return y

    naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
    y_naive = run_naive()

    grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
    y_grouped = run_grouped()

    reference = y_naive.float()
    diff = (reference - y_grouped.float()).abs()
    rel_l2 = (diff.pow(2).sum() / (reference.pow(2).sum() + 1e-12)).sqrt().item()

    flops = _estimate_flops(bsz * seq, hidden, inter, top_k)
    return Result(
        bsz,
        seq,
        hidden,
        inter,
        experts,
        top_k,
        dtype_name,
        naive_ms,
        grouped_ms,
        naive_ms / grouped_ms,
        flops / ((naive_ms / 1000.0) * 1e12),
        flops / ((grouped_ms / 1000.0) * 1e12),
        diff.max().item(),
        diff.mean().item(),
        rel_l2,
    )


def _csv_row(result: Result) -> Dict[str, str]:
    """Format one result as a CSV row keyed by ``Result`` field name."""
    return {
        f.name: format(getattr(result, f.name), _CSV_FLOAT_SPECS.get(f.name, ""))
        for f in fields(Result)
    }


def _write_csv(path: Path, results: List[Result]) -> None:
    """Write a header plus one row per result to ``path``, overwriting it."""
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=[fld.name for fld in fields(Result)])
        writer.writeheader()
        writer.writerows(_csv_row(r) for r in results)


def main() -> None:
    """Parse the command line, run the sweep, print and optionally save it."""
    p = argparse.ArgumentParser(description="Grouped MoE sweep")
    p.add_argument("--batch-sizes", default="4,8,16")
    p.add_argument("--seq-lens", default="512,1024,2048")
    p.add_argument("--hidden", default="2048,4096")
    p.add_argument("--inter", default="5632,8192,14336")
    p.add_argument("--experts", default="8,16,32")
    p.add_argument("--top-k", default="1,2,4")
    p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    p.add_argument("--iters", type=int, default=25)
    p.add_argument("--warmup", type=int, default=5)
    p.add_argument("--csv", type=Path, default=None)
    args = p.parse_args()

    # Fail fast instead of crashing inside the first timed case.
    if args.iters < 1:
        p.error("--iters must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    if tg is None or not tg.available():
        print("torch_grouped unavailable; sweep aborted")
        return

    # itertools.product iterates with the right-most list varying fastest,
    # i.e. the same order as nesting the loops in this order.
    sweep = itertools.product(
        _parse_list(args.batch_sizes),
        _parse_list(args.seq_lens),
        _parse_list(args.hidden),
        _parse_list(args.inter),
        _parse_list(args.experts),
        _parse_list(args.top_k),
    )

    results: List[Result] = []

    print(
        "bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
        "naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
    )

    for bsz, seq, hidden, inter, experts, top_k in sweep:
        res = _run_case(
            bsz,
            seq,
            hidden,
            inter,
            experts,
            top_k,
            dtype_name=args.dtype,
            dtype=dtype,
            device=device,
            iters=args.iters,
            warmup=args.warmup,
        )
        results.append(res)
        print(
            f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
            f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
            f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
        )

    if args.csv:
        _write_csv(args.csv, results)


if __name__ == "__main__":
    main()