263 lines
9.2 KiB
Python
263 lines
9.2 KiB
Python
#!/usr/bin/env python
|
||
import argparse
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
def _import_torch_grouped():
    """Return the ``axolotl.kernels.moe.torch_grouped`` module, or None if importing it fails."""
    try:
        from axolotl.kernels.moe import torch_grouped
    except Exception:  # pragma: no cover - fallback when torch_grouped unavailable
        return None
    return torch_grouped


# Optional grouped-GEMM MoE kernels; a None value makes the benchmark skip the
# torch_grouped comparison.
tg = _import_torch_grouped()
|
||
|
||
|
||
class SwiGLUMlp(nn.Module):
    """Gated SiLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` project ``hidden_size -> intermediate_size`` and ``w2``
    projects back to ``hidden_size``; none of the projections has a bias.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Registration order (w1, w3, w2) fixes both the RNG draw order during
        # weight initialisation and the state_dict key order -- keep it stable.
        self.w1 = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.w3 = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.w2 = nn.Linear(
            in_features=intermediate_size, out_features=hidden_size, bias=False
        )
        self.act_fn = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.w1(x))
        projected_up = self.w3(x)
        return self.w2(gated * projected_up)
|
||
|
||
|
||
class Experts(nn.Module):
    """Container of ``num_experts`` independent :class:`SwiGLUMlp` experts.

    ``experts[i]`` is forwarded to the underlying ``nn.ModuleList``
    (``self.layers``), and ``num_experts`` records how many were built.
    """

    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        expert_list = [
            SwiGLUMlp(hidden_size, intermediate_size) for _ in range(num_experts)
        ]
        self.layers = nn.ModuleList(expert_list)
        self.num_experts = num_experts

    def __getitem__(self, idx):
        layers = self.layers
        return layers[idx]
|
||
|
||
|
||
def forward_naive(
    hidden_states: torch.Tensor, gate: nn.Linear, experts: "Experts", top_k: int
):
    """Reference MoE forward: route tokens to their top-k experts, one expert at a time.

    Args:
        hidden_states: rank-3 ``(batch, seq, hidden)`` activations. Non-contiguous
            tensors (e.g. transposed activations) are accepted.
        gate: router mapping ``hidden`` features to one logit per expert.
        experts: object exposing ``num_experts`` and ``experts[i](rows)``.
        top_k: number of experts per token. The selected softmax probabilities
            are renormalised to sum to 1 before mixing.

    Returns:
        Tensor with the same shape and dtype as ``hidden_states``.
    """
    bsz, seqlen, hdim = hidden_states.shape
    # reshape (not view) so non-contiguous inputs do not raise.
    x = hidden_states.reshape(-1, hdim)
    router_logits = gate(x)
    # Softmax in fp32, then renormalise the chosen top-k weights.
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
    topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)

    # One row per (token, slot): row t * top_k + j is token t for its j-th expert.
    x_rep = x.repeat_interleave(top_k, dim=0)
    y = torch.empty_like(x_rep)
    flat_idx = topk_idx.reshape(-1)
    for i in range(experts.num_experts):
        sel = flat_idx == i
        if sel.any():
            y[sel] = experts[i](x_rep[sel])

    # Weighted sum over each token's top_k expert outputs.
    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    return y.view(bsz, seqlen, hdim)
|
||
|
||
|
||
def bench(fn, iters=50, warmup=10, sync=True):
    """Return the mean wall-clock latency of ``fn()`` in milliseconds.

    Args:
        fn: zero-argument callable to time; its return value is ignored.
        iters: number of timed calls; must be >= 1.
        warmup: number of untimed calls made before timing starts.
        sync: when True and CUDA is available, call ``torch.cuda.synchronize()``
            after warmup and around every timed call, so asynchronously queued
            GPU work is included in each measurement.

    Raises:
        ValueError: if ``iters`` is less than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # Decide once, outside the timing loop.
    do_sync = sync and torch.cuda.is_available()

    for _ in range(warmup):
        fn()
    if do_sync:
        torch.cuda.synchronize()

    total_ms = 0.0
    for _ in range(iters):
        if do_sync:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn()
        if do_sync:
            torch.cuda.synchronize()
        total_ms += (time.perf_counter() - t0) * 1000.0
    return total_ms / iters
|
||
|
||
|
||
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    """Return the approximate GEMM FLOP count of a routed SwiGLU MoE MLP.

    Every token is sent to ``top_k`` experts, so the expert GEMMs process
    ``tokens * top_k`` rows. Each expert runs three ``hidden x inter`` GEMMs:
    the two up projections (w1, w3) and the down projection (w2). A GEMM
    costs ``2*m*k*n`` FLOPs, which gives
    ``6 * tokens * top_k * hidden * inter`` in total. Router and elementwise
    work are not counted.
    """
    flops_per_mac = 2.0  # one multiply + one add
    gemms_per_expert = 3  # w1, w3, w2
    routed_rows = tokens * top_k
    return flops_per_mac * gemms_per_expert * routed_rows * hidden * inter
|
||
|
||
|
||
def _parse_args() -> argparse.Namespace:
    """Define and parse the benchmark's command-line options."""
    p = argparse.ArgumentParser(description="MoE microbenchmark")
    p.add_argument("--bsz", type=int, default=8)
    p.add_argument("--seq", type=int, default=1024)
    p.add_argument("--hidden", type=int, default=4096)
    p.add_argument("--inter", type=int, default=14336)
    p.add_argument("--experts", type=int, default=8)
    p.add_argument("--top_k", type=int, default=2)
    p.add_argument(
        "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
    )
    p.add_argument("--iters", type=int, default=50)
    p.add_argument("--warmup", type=int, default=10)
    p.add_argument(
        "--hf-block",
        type=str,
        default="none",
        choices=["none", "qwen2_moe"],
        help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
    )
    p.add_argument(
        "--profile",
        action="store_true",
        help="Capture CUDA profiler tables for naive and grouped runs.",
    )
    return p.parse_args()


def _grouped_available() -> bool:
    """True when ``tg`` was imported and ``tg.available()`` is truthy."""
    return bool(tg is not None and tg.available())


def _build_hf_runners(
    args: argparse.Namespace, device: str, dtype: torch.dtype, tokens: int
) -> tuple:
    """Build naive/grouped runners around two identically initialised HF MoE blocks.

    Returns:
        ``(run_naive_model, run_grouped_model, flops_total)``.
    """
    project_root = Path(__file__).resolve().parents[2]
    transformers_src = project_root / "transformers" / "src"
    if transformers_src.exists() and str(transformers_src) not in sys.path:
        # Appended, so an already-installed transformers package is found first.
        sys.path.append(str(transformers_src))

    if args.hf_block != "qwen2_moe":
        raise ValueError(f"unsupported --hf-block: {args.hf_block!r}")

    from transformers.models.qwen2_moe.configuration_qwen2_moe import (
        Qwen2MoeConfig,
    )
    from transformers.models.qwen2_moe.modeling_qwen2_moe import (
        Qwen2MoeSparseMoeBlock,
    )

    cfg = Qwen2MoeConfig(
        hidden_size=args.hidden,
        moe_intermediate_size=args.inter,
        shared_expert_intermediate_size=args.inter,
        num_experts=args.experts,
        num_experts_per_tok=args.top_k,
        norm_topk_prob=True,
        qkv_bias=True,
    )

    # Construction order matters: both blocks draw from the seeded RNG, and the
    # grouped block is then overwritten with the naive block's weights.
    block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped.load_state_dict(block_naive.state_dict())

    def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
        out, _ = block_naive(inp)
        return out

    def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
        if not _grouped_available():
            return torch.empty(0)
        block_grouped.experts._ax_parent_block = block_grouped
        y, _ = tg.moe_ffn_forward_grouped(
            inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
        )
        return y if y is not None else torch.empty(0)

    # Qwen2MoeSparseMoeBlock also runs a shared expert MLP on every token, sized
    # by shared_expert_intermediate_size (== args.inter above). Count it as one
    # extra SwiGLU pass per token on top of the routed experts.
    flops_total = estimate_moe_flops(
        tokens, args.hidden, args.inter, args.top_k
    ) + estimate_moe_flops(tokens, args.hidden, args.inter, 1)
    return run_naive_model, run_grouped_model, flops_total


def _build_toy_runners(
    args: argparse.Namespace, device: str, dtype: torch.dtype, tokens: int
) -> tuple:
    """Build naive/grouped runners around the toy SwiGLU experts and a linear router.

    Returns:
        ``(run_naive_model, run_grouped_model, flops_total)``.
    """
    # Experts before gate: keeps the seeded RNG draw order stable.
    experts = Experts(args.experts, args.hidden, args.inter).to(
        device=device, dtype=dtype
    )
    gate = nn.Linear(args.hidden, args.experts, bias=False).to(
        device=device, dtype=dtype
    )

    def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
        return forward_naive(inp, gate, experts, args.top_k)

    def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
        if not _grouped_available():
            return torch.empty(0)
        y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
        return y if y is not None else torch.empty(0)

    flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
    return run_naive_model, run_grouped_model, flops_total


def _report_grouped(
    run_grouped_model, x, y_ref, t_naive, tokens, flops_total, args
) -> None:
    """Time the grouped runner; print latency, throughput and its error vs ``y_ref``."""
    # Produce the comparison output under no_grad, as y_ref is. With grad
    # enabled, y_grouped would keep its autograd graph (and any activations
    # saved for backward) alive for the entire grouped benchmark below.
    with torch.no_grad():
        y_grouped = run_grouped_model(x)
    if y_grouped.numel() == 0:
        print("torch_grouped\tN/A (op not callable)")
        return

    t_grouped = bench(
        lambda: run_grouped_model(x),
        iters=args.iters,
        warmup=args.warmup,
    )
    tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
    speedup = t_naive / t_grouped
    print(
        f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
        f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
    )

    y_ref_f = y_ref.float()
    diff = (y_ref_f - y_grouped.float()).abs()
    max_abs = diff.max().item()
    mean_abs = diff.mean().item()
    rel_l2 = (diff.pow(2).sum() / (y_ref_f.pow(2).sum() + 1e-12)).sqrt().item()
    print(
        f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
    )


def _print_cuda_profile(fn, inp: torch.Tensor) -> None:
    """Profile a single ``fn(inp)`` call on CUDA and print the top-20 ops by CUDA time."""
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
    ) as prof:
        fn(inp)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))


def main():
    """Benchmark a naive MoE forward against torch_grouped and report the results."""
    args = _parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[args.dtype]

    torch.manual_seed(0)
    if device == "cuda":
        torch.cuda.manual_seed(0)

    # Input is drawn before any module is built, so weight init sees the same
    # RNG state on every run.
    x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
    tokens = args.bsz * args.seq

    if args.hf_block != "none":
        run_naive_model, run_grouped_model, flops_total = _build_hf_runners(
            args, device, dtype, tokens
        )
        header_suffix = f" hf_block={args.hf_block}"
    else:
        run_naive_model, run_grouped_model, flops_total = _build_toy_runners(
            args, device, dtype, tokens
        )
        header_suffix = ""
    print(
        f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
        f"experts={args.experts} top_k={args.top_k}{header_suffix}"
    )

    # Benchmark naive
    t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
    tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
    print(
        f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
    )

    with torch.no_grad():
        y_ref = run_naive_model(x)

    # Benchmark grouped
    grouped_ok = _grouped_available()
    if grouped_ok:
        _report_grouped(
            run_grouped_model, x, y_ref, t_naive, tokens, flops_total, args
        )
    else:
        print("torch_grouped\tN/A (unavailable)")

    if args.profile:
        if grouped_ok:
            _print_cuda_profile(run_naive_model, x)
            _print_cuda_profile(run_grouped_model, x)
        else:
            print("profile skipped: torch_grouped unavailable")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the benchmark CLI only when executed as a script, not on import.
    main()
|