#!/usr/bin/env python
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""

from __future__ import annotations

import argparse
import csv
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List

import torch

# Make the sibling torchtitan checkout importable before the MoE import below.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
    sys.path.append(str(_TITAN_PATH))

from torchtitan.models.moe import MoE, MoEArgs
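
# Example invocation (the flags are the ones defined in _parse_args below;
# the script path and results.csv are illustrative):
#   python path/to/this_script.py --batch-sizes 4,8 --seq-lens 1024 \
#       --experts 8,16 --top-ks 2 --dtype bf16 --csv results.csv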


def _parse_int_list(value: str) -> List[int]:
    """Parse a comma-separated list of ints, e.g. "4,8,16" -> [4, 8, 16]."""
    return [int(v) for v in value.split(",") if v]


def _parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
    p.add_argument(
        "--batch-sizes", default="4,8,16", help="Comma-separated batch sizes"
    )
    p.add_argument(
        "--seq-lens", default="1024,2048", help="Comma-separated sequence lengths"
    )
    p.add_argument(
        "--experts", default="8,16,32,64", help="Comma-separated expert counts"
    )
    p.add_argument("--top-ks", default="1,2,4", help="Comma-separated top_k choices")
    p.add_argument("--hidden", type=int, default=4096, help="Model hidden dimension")
    p.add_argument(
        "--inter", type=int, default=14336, help="Expert FFN intermediate dimension"
    )
    p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    p.add_argument("--iters", type=int, default=25, help="Timed iterations per case")
    p.add_argument("--warmup", type=int, default=5, help="Warmup iterations per case")
    p.add_argument("--init-std", type=float, default=0.02)
    p.add_argument("--score-before", action="store_true")
    p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
    p.add_argument("--route-norm", action="store_true")
    p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
    return p.parse_args()


def _map_dtype(arg: str) -> torch.dtype:
    return {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[arg]


def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    """Forward FLOPs for the routed expert FFNs only (router cost ignored).

    Each routed token hits three hidden x inter matmuls (SwiGLU-style w1, w3,
    w2) at 2 FLOPs per multiply-accumulate, i.e. 6 * hidden * inter per token
    per selected expert.
    """
    return 6.0 * tokens * top_k * hidden * inter
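
# Sanity check of the estimate with illustrative numbers (not measured data):
# bsz=4, seq=1024 -> 4096 tokens; hidden=4096, inter=14336, top_k=2 gives
# 6 * 4096 * 2 * 4096 * 14336 ≈ 2.89e12 FLOPs, i.e. ~2.9 TFLOPs per forward.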


def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
    """Move a module to `device`, casting its weights to `dtype` in place.

    Parameters are cast individually (rather than via `module.to(dtype)`) so
    that the router bookkeeping buffers below can deliberately stay float32.
    """
    module = module.to(device=device)
    for param in module.parameters():
        param.data = param.data.to(dtype)
        if param.grad is not None:
            param.grad = None
    for name, buf in module.named_buffers():
        if name in ("tokens_per_expert", "expert_bias"):
            # Router bookkeeping stays float32 regardless of the sweep dtype.
            module._buffers[name] = torch.zeros_like(
                buf, dtype=torch.float32, device=device
            )
        else:
            module._buffers[name] = buf.to(device=device, dtype=dtype)
    module.eval()
    return module


@torch.inference_mode()
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
    return module(x)


def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
    """Return the mean wall-clock latency of `callable_` in milliseconds."""
    for _ in range(warmup):
        callable_()
    if device.type == "cuda":
        torch.cuda.synchronize()
    timings: List[float] = []
    for _ in range(iters):
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        callable_()
        if device.type == "cuda":
            torch.cuda.synchronize()
        timings.append((time.perf_counter() - start) * 1000.0)
    return sum(timings) / len(timings)
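

# A minimal alternative timer using CUDA events (a sketch, not wired into the
# sweep): it measures device time between recorded events, which avoids the
# per-iteration host synchronization that _bench pays. CUDA-only by design.
def _bench_cuda_events(callable_, *, iters: int, warmup: int) -> float:
    for _ in range(warmup):
        callable_()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        callable_()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds between the two recorded events.
    return start.elapsed_time(end) / iters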


@dataclass
class SweepResult:
    bsz: int
    seq: int
    experts: int
    top_k: int
    dtype: str
    naive_ms: float
    grouped_ms: float
    speedup: float
    naive_tflops: float
    grouped_tflops: float
    max_abs: float
    mean_abs: float
    rel_l2: float


def _run_case(
    *,
    bsz: int,
    seq: int,
    experts: int,
    top_k: int,
    hidden: int,
    inter: int,
    dtype: torch.dtype,
    device: torch.device,
    iters: int,
    warmup: int,
    init_std: float,
    score_before: bool,
    score_func: str,
    route_norm: bool,
) -> SweepResult:
    torch.manual_seed(0)
    if device.type == "cuda":
        torch.cuda.manual_seed(0)

    moe_args_grouped = MoEArgs(
        num_experts=experts,
        num_shared_experts=0,
        score_func=score_func,
        route_norm=route_norm,
        top_k=top_k,
        use_grouped_mm=True,
        score_before_experts=score_before,
        load_balance_coeff=None,
    )
    moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
    moe_grouped.init_weights(init_std, buffer_device=device)

    moe_args_naive = MoEArgs(
        num_experts=experts,
        num_shared_experts=0,
        score_func=score_func,
        route_norm=route_norm,
        top_k=top_k,
        use_grouped_mm=False,
        score_before_experts=score_before,
        load_balance_coeff=None,
    )
    moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
    # Copy the grouped module's weights so both paths compute from identical
    # parameters and any output difference is purely kernel-level.
    moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)

    moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
    moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)

    x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)

    def run_naive():
        if hasattr(moe_naive, "tokens_per_expert"):
            moe_naive.tokens_per_expert.zero_()
        return _forward(moe_naive, x)

    def run_grouped():
        if hasattr(moe_grouped, "tokens_per_expert"):
            moe_grouped.tokens_per_expert.zero_()
        return _forward(moe_grouped, x)

    naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
    y_naive = run_naive()

    grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
    y_grouped = run_grouped()

    diff = (y_naive.float() - y_grouped.float()).abs()
    max_abs = diff.max().item()
    mean_abs = diff.mean().item()
    rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()

    tokens = bsz * seq
    flops = _estimate_flops(tokens, hidden, inter, top_k)
    naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
    grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
    speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")

    return SweepResult(
        bsz=bsz,
        seq=seq,
        experts=experts,
        top_k=top_k,
        dtype=str(dtype),
        naive_ms=naive_ms,
        grouped_ms=grouped_ms,
        speedup=speedup,
        naive_tflops=naive_tflops,
        grouped_tflops=grouped_tflops,
        max_abs=max_abs,
        mean_abs=mean_abs,
        rel_l2=rel_l2,
    )
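

# Rough weight-memory arithmetic for sizing a sweep (illustrative; assumes
# three hidden x inter projections per expert, matching the FLOP estimate
# above): bytes_per_expert ≈ 3 * hidden * inter * itemsize. With the default
# hidden=4096, inter=14336 in bf16 that is ~336 MiB per expert, and _run_case
# keeps both the grouped and naive copies resident, so a 64-expert case needs
# roughly 2 * 64 * 336 MiB ≈ 42 GiB for the expert weights alone.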


def _print_header(
    hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
    print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
    print(
        "bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
        "naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
    )


def _print_result(res: SweepResult) -> None:
    print(
        f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
        f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
        f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
        f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
    )


def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
    fieldnames = [
        "batch_size",
        "seq_len",
        "experts",
        "top_k",
        "dtype",
        "naive_ms",
        "grouped_ms",
        "speedup",
        "naive_tflops",
        "grouped_tflops",
        "max_abs",
        "mean_abs",
        "rel_l2",
    ]
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for r in results:
            writer.writerow(
                {
                    "batch_size": r.bsz,
                    "seq_len": r.seq,
                    "experts": r.experts,
                    "top_k": r.top_k,
                    "dtype": r.dtype,
                    "naive_ms": f"{r.naive_ms:.4f}",
                    "grouped_ms": f"{r.grouped_ms:.4f}",
                    "speedup": f"{r.speedup:.4f}",
                    "naive_tflops": f"{r.naive_tflops:.4f}",
                    "grouped_tflops": f"{r.grouped_tflops:.4f}",
                    "max_abs": f"{r.max_abs:.6e}",
                    "mean_abs": f"{r.mean_abs:.6e}",
                    "rel_l2": f"{r.rel_l2:.6e}",
                }
            )
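

# The CSV is plain DictWriter output, so standard tooling can read it back;
# a quick way to eyeball speedups (results.csv is an illustrative path):
#   python - <<'EOF'
#   import csv
#   for r in csv.DictReader(open("results.csv")):
#       print(r["experts"], r["top_k"], r["speedup"])
#   EOF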


def main() -> None:
    args = _parse_args()
    dtype = _map_dtype(args.dtype)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    batch_sizes = _parse_int_list(args.batch_sizes)
    seq_lens = _parse_int_list(args.seq_lens)
    experts_list = _parse_int_list(args.experts)
    top_ks = _parse_int_list(args.top_ks)

    results: List[SweepResult] = []
    _print_header(args.hidden, args.inter, dtype, device)

    for bsz in batch_sizes:
        for seq in seq_lens:
            for experts in experts_list:
                for top_k in top_ks:
                    try:
                        res = _run_case(
                            bsz=bsz,
                            seq=seq,
                            experts=experts,
                            top_k=top_k,
                            hidden=args.hidden,
                            inter=args.inter,
                            dtype=dtype,
                            device=device,
                            iters=args.iters,
                            warmup=args.warmup,
                            init_std=args.init_std,
                            score_before=args.score_before,
                            score_func=args.score_func,
                            route_norm=args.route_norm,
                        )
                    except RuntimeError as err:
                        # Typically out-of-memory on the larger cases; log the
                        # failure and keep sweeping the remaining configs.
                        print(
                            f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
                            file=sys.stderr,
                        )
                        continue
                    results.append(res)
                    _print_result(res)

    if args.csv and results:
        _write_csv(args.csv, results)
        print(f"Wrote {len(results)} rows to {args.csv}")


if __name__ == "__main__":
    main()