diff --git a/scripts/bench_torchtitan_moe_sweep.py b/scripts/bench_torchtitan_moe_sweep.py new file mode 100644 index 000000000..61c5c77aa --- /dev/null +++ b/scripts/bench_torchtitan_moe_sweep.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python +"""Sweep Torchtitan MoE grouped vs naive configurations and report performance.""" + +from __future__ import annotations + +import argparse +import csv +import math +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, List + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[2] +_TITAN_PATH = _PROJECT_ROOT / "torchtitan" +if str(_TITAN_PATH) not in sys.path: + sys.path.append(str(_TITAN_PATH)) + +from torchtitan.models.moe import MoE, MoEArgs + + +def _parse_int_list(value: str) -> List[int]: + return [int(v) for v in value.split(",") if v] + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep") + p.add_argument("--batch-sizes", default="4,8,16", help="Comma separated batch sizes") + p.add_argument("--seq-lens", default="1024,2048", help="Comma separated sequence lengths") + p.add_argument("--experts", default="8,16,32,64", help="Comma separated expert counts") + p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices") + p.add_argument("--hidden", type=int, default=4096) + p.add_argument("--inter", type=int, default=14336) + p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16") + p.add_argument("--iters", type=int, default=25) + p.add_argument("--warmup", type=int, default=5) + p.add_argument("--init-std", type=float, default=0.02) + p.add_argument("--score-before", action="store_true") + p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax") + p.add_argument("--route-norm", action="store_true") + p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path") + return p.parse_args() + + +def _map_dtype(arg: str) -> torch.dtype: + return { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + }[arg] + + +def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float: + return 6.0 * tokens * top_k * hidden * inter + + +def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE: + module = module.to(device=device) + for param in module.parameters(): + param.data = param.data.to(dtype) + if param.grad is not None: + param.grad = None + for name, buf in module.named_buffers(): + if name == "tokens_per_expert": + module._buffers[name] = torch.zeros_like( + buf, dtype=torch.float32, device=device + ) + elif name == "expert_bias" and buf is not None: + module._buffers[name] = torch.zeros_like( + buf, dtype=torch.float32, device=device + ) + else: + module._buffers[name] = buf.to(device=device, dtype=dtype) + module.eval() + return module + + +@torch.inference_mode() +def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor: + return module(x) + + +def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float: + for _ in range(warmup): + callable_() + if device.type == "cuda": + torch.cuda.synchronize() + timings: List[float] = [] + for _ in range(iters): + if device.type == "cuda": + torch.cuda.synchronize() + start = time.perf_counter() + callable_() + if device.type == "cuda": + torch.cuda.synchronize() + timings.append((time.perf_counter() - start) * 1000.0) + return sum(timings) / len(timings) + + +@dataclass +class SweepResult: + bsz: int + seq: int + experts: int + top_k: int + dtype: str + naive_ms: float + grouped_ms: float + speedup: float + naive_tflops: float + grouped_tflops: float + max_abs: float + mean_abs: float + rel_l2: float + + +def _run_case( + *, + bsz: int, + seq: int, + experts: int, + top_k: int, + hidden: int, + inter: int, + dtype: torch.dtype, + device: torch.device, + iters: int, + warmup: int, + init_std: float, + score_before: bool, + score_func: str, + route_norm: bool, +) -> SweepResult: + torch.manual_seed(0) + if device.type == "cuda": + torch.cuda.manual_seed(0) + + moe_args_grouped = MoEArgs( + num_experts=experts, + num_shared_experts=0, + score_func=score_func, + route_norm=route_norm, + top_k=top_k, + use_grouped_mm=True, + score_before_experts=score_before, + load_balance_coeff=None, + ) + moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter) + moe_grouped.init_weights(init_std, buffer_device=device) + + moe_args_naive = MoEArgs( + num_experts=experts, + num_shared_experts=0, + score_func=score_func, + route_norm=route_norm, + top_k=top_k, + use_grouped_mm=False, + score_before_experts=score_before, + load_balance_coeff=None, + ) + moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter) + moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True) + + moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype) + moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype) + + x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype) + + def run_naive(): + if hasattr(moe_naive, "tokens_per_expert"): + moe_naive.tokens_per_expert.zero_() + return _forward(moe_naive, x) + + def run_grouped(): + if hasattr(moe_grouped, "tokens_per_expert"): + moe_grouped.tokens_per_expert.zero_() + return _forward(moe_grouped, x) + + naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device) + y_naive = run_naive() + + grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device) + y_grouped = run_grouped() + + diff = (y_naive.float() - y_grouped.float()).abs() + max_abs = diff.max().item() + mean_abs = diff.mean().item() + rel_l2 = ( + (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item() + ) + + tokens = bsz * seq + flops = _estimate_flops(tokens, hidden, inter, top_k) + naive_tflops = flops / ((naive_ms / 1000.0) * 1e12) + grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12) + speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan") + + return SweepResult( + bsz=bsz, + seq=seq, + experts=experts, + top_k=top_k, + dtype=str(dtype), + naive_ms=naive_ms, + grouped_ms=grouped_ms, + speedup=speedup, + naive_tflops=naive_tflops, + grouped_tflops=grouped_tflops, + max_abs=max_abs, + mean_abs=mean_abs, + rel_l2=rel_l2, + ) + + +def _print_header(hidden: int, inter: int, dtype: torch.dtype, device: torch.device) -> None: + print( + f"Device={device} dtype={dtype} hidden={hidden} inter={inter}" + ) + print( + "bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t" + "naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2" + ) + + +def _print_result(res: SweepResult) -> None: + print( + f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t" + f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t" + f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t" + f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}" + ) + + +def _write_csv(path: Path, results: Iterable[SweepResult]) -> None: + fieldnames = [ + "batch_size", + "seq_len", + "experts", + "top_k", + "dtype", + "naive_ms", + "grouped_ms", + "speedup", + "naive_tflops", + "grouped_tflops", + "max_abs", + "mean_abs", + "rel_l2", + ] + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for r in results: + writer.writerow( + { + "batch_size": r.bsz, + "seq_len": r.seq, + "experts": r.experts, + "top_k": r.top_k, + "dtype": r.dtype, + "naive_ms": f"{r.naive_ms:.4f}", + "grouped_ms": f"{r.grouped_ms:.4f}", + "speedup": f"{r.speedup:.4f}", + "naive_tflops": f"{r.naive_tflops:.4f}", + "grouped_tflops": f"{r.grouped_tflops:.4f}", + "max_abs": f"{r.max_abs:.6e}", + "mean_abs": f"{r.mean_abs:.6e}", + "rel_l2": f"{r.rel_l2:.6e}", + } + ) + + +def main() -> None: + args = _parse_args() + dtype = _map_dtype(args.dtype) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + batch_sizes = _parse_int_list(args.batch_sizes) + seq_lens = _parse_int_list(args.seq_lens) + experts_list = _parse_int_list(args.experts) + top_ks = _parse_int_list(args.top_ks) + + results: List[SweepResult] = [] + _print_header(args.hidden, args.inter, dtype, device) + + for bsz in batch_sizes: + for seq in seq_lens: + for experts in experts_list: + for top_k in top_ks: + try: + res = _run_case( + bsz=bsz, + seq=seq, + experts=experts, + top_k=top_k, + hidden=args.hidden, + inter=args.inter, + dtype=dtype, + device=device, + iters=args.iters, + warmup=args.warmup, + init_std=args.init_std, + score_before=args.score_before, + score_func=args.score_func, + route_norm=args.route_norm, + ) + except RuntimeError as err: + print( + f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}", + file=sys.stderr, + ) + continue + results.append(res) + _print_result(res) + + if args.csv and results: + _write_csv(args.csv, results) + print(f"Wrote {len(results)} rows to {args.csv}") + + +if __name__ == "__main__": + main()