#!/usr/bin/env python
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""

from __future__ import annotations

import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import torch

CURRENT_DIR = Path(__file__).resolve().parent
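# Search upward from this file's directory for an "axolotl" checkout and put
# it on sys.path before importing from it. The for/else raises only when no
# ancestor contains one (a for-loop's else runs iff the loop never breaks).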
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
    repo_root = candidate / "axolotl"
    if repo_root.exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
        break
else:
    raise SystemExit("Unable to locate axolotl repository root for imports")

from axolotl.kernels.moe import (
    cg_grouped_gemm_forward,
    cg_grouped_gemm_forward_dynamic,
)


@dataclass
class Scenario:
    """One grouped-GEMM problem: m rows split evenly across num_groups groups."""

    num_groups: int  # number of expert groups sharing the launch
    m: int  # total rows across all groups
    n: int  # output features per group
    k: int  # reduction (input-feature) dimension
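

# Problem sizes taken from the TorchTitan CG GEMM timings this script
# reproduces.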
SCENARIOS: tuple[Scenario, ...] = (
    Scenario(num_groups=4, m=8192, n=4096, k=7168),
    Scenario(num_groups=4, m=8192, n=7168, k=2048),
    Scenario(num_groups=8, m=4096, n=4096, k=7168),
    Scenario(num_groups=8, m=4096, n=7168, k=2048),
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--device", default="cuda", choices=["cuda"], help="Execution device"
    )
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=["bf16", "fp16", "fp32"],
        help="Computation dtype",
    )
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--group-size",
        type=int,
        default=128,
        help="GROUP_SIZE_M expected by the kernel",
    )
    return parser.parse_args()


def pick_dtype(name: str) -> torch.dtype:
    return {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[name]
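

# make_indices produces a sorted row-to-group assignment: for example,
# num_groups=2 with group_size=3 yields [0, 0, 0, 1, 1, 1], so each group
# owns one contiguous, equally sized block of rows.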
def make_indices(
    num_groups: int, group_size: int, device: torch.device
) -> torch.Tensor:
    indices = torch.arange(num_groups, device=device, dtype=torch.int32)
    return indices.repeat_interleave(group_size)
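

# CUDA launches are asynchronous, so torch.cuda.synchronize() is called once
# after warmup and once after the timed loop; the result is the mean
# wall-clock latency per call in milliseconds.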
def timed_call(fn, *args, warmup: int, iters: int) -> float:
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000.0 / iters


def run_scenario(
    scenario: Scenario,
    *,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    iters: int,
    group_size_m: int,
) -> dict:
    if scenario.m % scenario.num_groups != 0:
        raise ValueError(
            f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
        )
    group_size = scenario.m // scenario.num_groups
    if group_size % group_size_m != 0:
        raise ValueError(
            f"Group size {group_size} must be a multiple of GROUP_SIZE_M"
            f" ({group_size_m}) for the Triton kernel"
        )

    # Inputs are [M, K]; weights hold one [N, K] matrix per group; indices map
    # each input row to the group whose weight it is multiplied by.
    inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
    weights = torch.randn(
        scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
    )
    indices = make_indices(scenario.num_groups, group_size, device)

    def persistent():
        return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)

    def baseline():
        return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)

    persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
    baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)

    return {
        "scenario": scenario,
        "persistent_ms": persistent_ms,
        "baseline_ms": baseline_ms,
        "speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
    }
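

# Optional helper, an illustrative addition not used by the benchmark above:
# convert a measured latency into achieved TFLOP/s. Each of the M rows is
# multiplied by one [N, K] group weight, so the grouped GEMM performs
# 2 * M * N * K FLOPs in total, independent of num_groups.
def achieved_tflops(scenario: Scenario, ms: float) -> float:
    flops = 2 * scenario.m * scenario.n * scenario.k
    return flops / (ms * 1e-3) / 1e12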


def main() -> None:  # pragma: no cover - utility script
    args = parse_args()
    torch.manual_seed(args.seed)

    if args.device != "cuda" or not torch.cuda.is_available():
        raise SystemExit("CUDA device required for this benchmark")

    dtype = pick_dtype(args.dtype)
    device = torch.device(args.device)

    print(
        f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters}"
        f" group_size={args.group_size}"
    )
    print(
        f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7}"
        f" {'persistent':>12} {'baseline':>12} {'speedup':>8}"
    )
    for result in run_all(
        SCENARIOS,
        dtype=dtype,
        device=device,
        warmup=args.warmup,
        iters=args.iters,
        group_size_m=args.group_size,
    ):
        scen = result["scenario"]
        print(
            f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
            f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms"
            f" {result['speedup']:>7.2f}x"
        )


def run_all(
    scenarios: Iterable[Scenario],
    *,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    iters: int,
    group_size_m: int,
) -> Iterable[dict]:
    for scenario in scenarios:
        yield run_scenario(
            scenario,
            dtype=dtype,
            device=device,
            warmup=warmup,
            iters=iters,
            group_size_m=group_size_m,
        )
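

# Example invocation (the filename is illustrative):
#   python bench_cg_grouped_gemm.py --dtype bf16 --warmup 5 --iters 20 \
#       --group-size 128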
if __name__ == "__main__":
    main()