#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""

from __future__ import annotations

import argparse
import csv
import itertools
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List

import torch

try:
    from axolotl.kernels.moe import torch_grouped as tg
except Exception:  # pragma: no cover
    tg = None


def _parse_list(arg: str) -> List[int]:
    """Parse a comma-separated string of integers.

    Empty and whitespace-only items (e.g. from a trailing comma or
    ``"4, ,8"``) are skipped; surrounding whitespace on an item is ignored.

    Raises:
        ValueError: if a non-blank item is not a valid integer.
    """
    return [int(v) for v in arg.split(",") if v.strip()]


def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
    """Return the mean wall-clock time of ``run()`` in milliseconds.

    ``run`` is first called ``warmup`` times untimed, then ``iters`` times
    timed. On a CUDA device the stream is synchronized before starting and
    after finishing each timed call so asynchronous kernels are included.

    Raises:
        ValueError: if ``iters`` is less than 1 (no samples to average).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    on_cuda = device.type == "cuda"

    for _ in range(warmup):
        run()
    if on_cuda:
        torch.cuda.synchronize()

    times: List[float] = []
    for _ in range(iters):
        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        run()
        if on_cuda:
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000.0)
    return sum(times) / len(times)


def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    """Return the FLOP estimate ``6 * tokens * top_k * hidden * inter``.

    NOTE(review): the factor 6 presumably stands for three hidden x inter
    projections at 2 FLOPs per multiply-add; confirm against the expert MLP.
    """
    return 6.0 * tokens * top_k * hidden * inter


def _load_block(
    hidden: int,
    inter: int,
    experts: int,
    top_k: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build two ``Qwen2MoeSparseMoeBlock`` instances with identical weights.

    If ``<project_root>/transformers/src`` exists (``project_root`` being two
    directories above this file's directory) it is appended to ``sys.path``
    before ``transformers`` is imported.

    Returns:
        ``(block, block_grouped)``; the second is loaded from the first's
        state dict, and both live on ``device`` with ``dtype``.
    """
    project_root = Path(__file__).resolve().parents[2]
    transformers_src = project_root / "transformers" / "src"
    if transformers_src.exists() and str(transformers_src) not in sys.path:
        sys.path.append(str(transformers_src))
    from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

    cfg = Qwen2MoeConfig(
        hidden_size=hidden,
        moe_intermediate_size=inter,
        shared_expert_intermediate_size=inter,
        num_experts=experts,
        num_experts_per_tok=top_k,
        norm_topk_prob=True,
        qkv_bias=True,
    )
    block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped.load_state_dict(block.state_dict())
    return block, block_grouped


@dataclass
class Result:
    """Timing and accuracy figures for one sweep configuration."""

    bsz: int
    seq: int
    hidden: int
    inter: int
    experts: int
    top_k: int
    dtype: str
    naive_ms: float
    grouped_ms: float
    speedup: float
    naive_tflops: float
    grouped_tflops: float
    max_abs: float
    mean_abs: float
    rel_l2: float


def _run_config(
    bsz: int,
    seq: int,
    hidden: int,
    inter: int,
    experts: int,
    top_k: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
    dtype_name: str,
    iters: int,
    warmup: int,
    use_compile: bool,
) -> Result:
    """Benchmark one configuration and compare naive vs grouped outputs.

    All tensors and modules are local to this call, so they become
    collectable on return instead of staying alive while the next
    configuration's weights are being allocated.
    """
    torch.manual_seed(0)
    if device.type == "cuda":
        torch.cuda.manual_seed(0)

    block_naive, block_grouped = _load_block(
        hidden,
        inter,
        experts,
        top_k,
        device=device,
        dtype=dtype,
    )
    x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)

    compiled_impl = None
    if use_compile:
        try:
            block_naive = torch.compile(block_naive)  # type: ignore[arg-type]
        except Exception as exc:
            print(f"torch.compile naive failed ({exc}); using eager")
        else:
            # The grouped path is only compiled when compiling the naive
            # block succeeded, so both sides run in the same mode.

            def grouped_forward(inp, *, block=block_grouped):
                # NOTE(review): presumably the grouped kernel uses this
                # weak reference to reach the parent block; confirm in
                # torch_grouped.
                block.experts._ax_parent_block_ref = weakref.ref(block)  # type: ignore[attr-defined]
                y, _ = tg.moe_ffn_forward_grouped(
                    inp,
                    block.gate,
                    block.experts,
                    block.top_k,
                )
                return y

            try:
                compiled_impl = torch.compile(grouped_forward)  # type: ignore[arg-type]
            except Exception as exc:
                print(f"torch.compile grouped failed ({exc}); using eager")
                compiled_impl = None

    def run_naive(block=block_naive, data=x):
        y, _ = block(data)
        return y

    def run_grouped(block=block_grouped, data=x, impl=compiled_impl):
        if impl is not None:
            return impl(data)
        block.experts._ax_parent_block_ref = weakref.ref(block)  # type: ignore[attr-defined]
        y, _ = tg.moe_ffn_forward_grouped(
            data,
            block.gate,
            block.experts,
            block.top_k,
        )
        return y

    naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
    y_naive = run_naive()
    grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
    y_grouped = run_grouped()

    diff = (y_naive.float() - y_grouped.float()).abs()
    flops = _estimate_flops(bsz * seq, hidden, inter, top_k)
    rel_l2 = (
        (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
    )
    return Result(
        bsz,
        seq,
        hidden,
        inter,
        experts,
        top_k,
        dtype_name,
        naive_ms,
        grouped_ms,
        naive_ms / grouped_ms,
        flops / ((naive_ms / 1000.0) * 1e12),
        flops / ((grouped_ms / 1000.0) * 1e12),
        diff.max().item(),
        diff.mean().item(),
        rel_l2,
    )


def _write_csv(path: Path, results: List[Result]) -> None:
    """Write ``results`` to ``path`` as CSV, one row per configuration."""
    fieldnames = [
        "bsz",
        "seq",
        "hidden",
        "inter",
        "experts",
        "top_k",
        "dtype",
        "naive_ms",
        "grouped_ms",
        "speedup",
        "naive_tflops",
        "grouped_tflops",
        "max_abs",
        "mean_abs",
        "rel_l2",
    ]
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for r in results:
            writer.writerow(
                {
                    "bsz": r.bsz,
                    "seq": r.seq,
                    "hidden": r.hidden,
                    "inter": r.inter,
                    "experts": r.experts,
                    "top_k": r.top_k,
                    "dtype": r.dtype,
                    "naive_ms": f"{r.naive_ms:.4f}",
                    "grouped_ms": f"{r.grouped_ms:.4f}",
                    "speedup": f"{r.speedup:.4f}",
                    "naive_tflops": f"{r.naive_tflops:.4f}",
                    "grouped_tflops": f"{r.grouped_tflops:.4f}",
                    "max_abs": f"{r.max_abs:.6e}",
                    "mean_abs": f"{r.mean_abs:.6e}",
                    "rel_l2": f"{r.rel_l2:.6e}",
                }
            )


def main() -> None:
    """Run the sweep over the Cartesian product of the CLI-provided lists.

    Prints one tab-separated row per configuration and, when ``--csv`` is
    given, writes all rows to that file after the sweep finishes. Returns
    early with a message when the grouped kernel module is unavailable.
    """
    p = argparse.ArgumentParser(description="Grouped MoE sweep")
    p.add_argument("--batch-sizes", default="4,8,16")
    p.add_argument("--seq-lens", default="512,1024,2048")
    p.add_argument("--hidden", default="2048,4096")
    p.add_argument("--inter", default="5632,8192,14336")
    p.add_argument("--experts", default="8,16,32")
    p.add_argument("--top-k", default="1,2,4")
    p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    p.add_argument("--iters", type=int, default=25)
    p.add_argument("--warmup", type=int, default=5)
    p.add_argument("--csv", type=Path, default=None)
    p.add_argument("--compile", action="store_true")
    args = p.parse_args()

    # Fail fast: with zero timed iterations there is nothing to average.
    if args.iters < 1:
        p.error("--iters must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    if tg is None or not tg.available():
        print("torch_grouped unavailable; sweep aborted")
        return

    # product() varies the right-most list fastest, matching the order of
    # the equivalent nested loops (bsz outermost, top_k innermost).
    grid = itertools.product(
        _parse_list(args.batch_sizes),
        _parse_list(args.seq_lens),
        _parse_list(args.hidden),
        _parse_list(args.inter),
        _parse_list(args.experts),
        _parse_list(args.top_k),
    )

    results: List[Result] = []

    print(
        "bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
        "naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
    )

    for bsz, seq, hidden, inter, experts, top_k in grid:
        res = _run_config(
            bsz,
            seq,
            hidden,
            inter,
            experts,
            top_k,
            device=device,
            dtype=dtype,
            dtype_name=args.dtype,
            iters=args.iters,
            warmup=args.warmup,
            use_compile=args.compile,
        )
        results.append(res)
        print(
            f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
            f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
            f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
        )

    if args.csv:
        _write_csv(args.csv, results)


if __name__ == "__main__":
    main()