bench fix

This commit is contained in:
Dan Saunders
2025-09-19 12:20:35 -04:00
parent 63544ce709
commit 1e7302d30a
2 changed files with 159 additions and 202 deletions

View File

@@ -1,121 +1,83 @@
#!/usr/bin/env python #!/usr/bin/env python
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
from __future__ import annotations
import argparse import argparse
import sys import sys
import time import time
from pathlib import Path from pathlib import Path
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
try: try:
from axolotl.kernels.moe import torch_grouped as tg from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover - fallback when torch_grouped unavailable except Exception: # pragma: no cover
tg = None tg = None
class SwiGLUMlp(nn.Module): def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
def __init__(self, hidden_size: int, intermediate_size: int): device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
super().__init__()
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
self.act_fn = F.silu
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(self.act_fn(self.w1(x)) * self.w3(x))
class Experts(nn.Module):
    """Indexable collection of per-expert SwiGLU MLPs.

    Args:
        num_experts: Number of expert MLPs to create.
        hidden_size: Input/output width forwarded to each ``SwiGLUMlp``.
        intermediate_size: Inner width forwarded to each ``SwiGLUMlp``.
    """

    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Build the experts eagerly, in index order, then register them as
        # submodules through the ModuleList.
        expert_mlps = [
            SwiGLUMlp(hidden_size, intermediate_size) for _ in range(num_experts)
        ]
        self.layers = nn.ModuleList(expert_mlps)
        self.num_experts = num_experts

    def __getitem__(self, idx):
        """Return the expert stored at position ``idx`` of ``self.layers``."""
        return self.layers[idx]
def forward_naive(
    hidden_states: torch.Tensor, gate: nn.Linear, experts: "Experts", top_k: int
):
    """Reference MoE forward pass that loops over experts one at a time.

    Args:
        hidden_states: 3-D tensor unpacked as ``(bsz, seqlen, hdim)``. It may be
            non-contiguous (e.g. a transposed view).
        gate: Router producing one logit per expert for every token.
        experts: Object exposing ``num_experts`` and ``experts[i]`` as a callable
            that maps a ``(n, hdim)`` tensor to a ``(n, hdim)`` tensor.
        top_k: Number of experts each token is routed to.

    Returns:
        Tensor of shape ``(bsz, seqlen, hdim)`` holding, for every token, the
        routing-weighted sum of its ``top_k`` expert outputs.
    """
    bsz, seqlen, hdim = hidden_states.shape
    # reshape (not view) so that non-contiguous inputs are accepted; for
    # contiguous inputs this is still a zero-copy view.
    x = hidden_states.reshape(-1, hdim)

    # Routing: softmax in float32, keep the top_k experts per token and
    # renormalise their weights so they sum to one.
    router_logits = gate(x)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
    topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)

    # Duplicate every token top_k times so row j of x_rep pairs with entry j of
    # the flattened expert index.
    x_rep = x.repeat_interleave(top_k, dim=0)
    y = torch.empty_like(x_rep)
    flat_idx = topk_idx.reshape(-1)
    for i in range(experts.num_experts):
        sel = flat_idx == i
        # Skip experts that received no tokens.
        if sel.any():
            y[sel] = experts[i](x_rep[sel])

    # (tokens, top_k, hdim) weighted by (tokens, top_k, 1), reduced over top_k.
    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    return y.view(bsz, seqlen, hdim)
def bench(fn, iters=50, warmup=10, sync=True):
# warmup
for _ in range(warmup): for _ in range(warmup):
fn() run()
if sync and torch.cuda.is_available(): if sync and device.type == "cuda":
torch.cuda.synchronize() torch.cuda.synchronize()
# measure
times = [] times = []
for _ in range(iters): for _ in range(iters):
if sync and torch.cuda.is_available(): if sync and device.type == "cuda":
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.perf_counter() start = time.perf_counter()
fn() run()
if sync and torch.cuda.is_available(): if sync and device.type == "cuda":
torch.cuda.synchronize() torch.cuda.synchronize()
dt = (time.perf_counter() - t0) * 1000.0 times.append((time.perf_counter() - start) * 1000.0)
times.append(dt)
return sum(times) / len(times) return sum(times) / len(times)
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float: def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
"""Estimate GEMM FLOPs for a SwiGLU MoE MLP. return 6.0 * tokens * top_k * hidden * inter
Two up projections (w1,w3) + one down (w2), each token processed by top_k experts.
FLOPs ≈ 6 * (tokens * top_k) * hidden * inter (2*m*k*n per GEMM).
"""
m_rep = tokens * top_k
return 6.0 * m_rep * hidden * inter
def main(): def load_hf_block(hidden: int, inter: int, experts: int, top_k: int, *, device: torch.device, dtype: torch.dtype):
p = argparse.ArgumentParser(description="MoE microbenchmark") project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8) p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024) p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096) p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336) p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=8) p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=2) p.add_argument("--top_k", type=int, default=4)
p.add_argument( p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
"--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
)
p.add_argument("--iters", type=int, default=50) p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10) p.add_argument("--warmup", type=int, default=10)
p.add_argument( p.add_argument("--profile", action="store_true")
"--hf-block",
type=str,
default="none",
choices=["none", "qwen2_moe"],
help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
)
p.add_argument(
"--profile",
action="store_true",
help="Capture CUDA profiler tables for naive and grouped runs.",
)
args = p.parse_args() args = p.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = { dtype = {
"bf16": torch.bfloat16, "bf16": torch.bfloat16,
"fp16": torch.float16, "fp16": torch.float16,
@@ -123,138 +85,80 @@ def main():
}[args.dtype] }[args.dtype]
torch.manual_seed(0) torch.manual_seed(0)
if device == "cuda": if device.type == "cuda":
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
# data block_naive, block_grouped = load_hf_block(
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype) args.hidden,
args.inter,
tokens = args.bsz * args.seq args.experts,
args.top_k,
use_hf = args.hf_block != "none" device=device,
dtype=dtype,
if use_hf:
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
if args.hf_block == "qwen2_moe":
from transformers.models.qwen2_moe.configuration_qwen2_moe import (
Qwen2MoeConfig,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
Qwen2MoeSparseMoeBlock,
)
cfg = Qwen2MoeConfig(
hidden_size=args.hidden,
moe_intermediate_size=args.inter,
shared_expert_intermediate_size=args.inter,
num_experts=args.experts,
num_experts_per_tok=args.top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block_naive.state_dict())
def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
out, _ = block_naive(inp)
return out
def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
if tg is None or not tg.available():
return torch.empty(0)
block_grouped.experts._ax_parent_block = block_grouped
y, _ = tg.moe_ffn_forward_grouped(
inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
)
return y if y is not None else torch.empty(0)
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k} hf_block={args.hf_block}"
)
else:
experts = Experts(args.experts, args.hidden, args.inter).to(
device=device, dtype=dtype
)
gate = nn.Linear(args.hidden, args.experts, bias=False).to(
device=device, dtype=dtype
)
def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
return forward_naive(inp, gate, experts, args.top_k)
def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
if tg is None or not tg.available():
return torch.empty(0)
y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
return y if y is not None else torch.empty(0)
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
# Benchmark naive
t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
) )
tokens = args.bsz * args.seq
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
def run_naive():
y, _ = block_naive(x)
return y
def run_grouped():
if tg is None or not tg.available():
return torch.empty(0)
block_grouped.experts._ax_parent_block = block_grouped
y, _ = tg.moe_ffn_forward_grouped(x, block_grouped.gate, block_grouped.experts, block_grouped.top_k)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s")
with torch.no_grad(): with torch.no_grad():
y_ref = run_naive_model(x) y_ref = run_naive()
# Benchmark grouped if tg is None or not tg.available():
if tg is not None and tg.available():
y_grouped = run_grouped_model(x)
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
else:
t_grouped = bench(
lambda: run_grouped_model(x),
iters=args.iters,
warmup=args.warmup,
)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item()
)
print(
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else:
print("torch_grouped\tN/A (unavailable)") print("torch_grouped\tN/A (unavailable)")
return
if args.profile and tg is not None and tg.available(): y_grouped = run_grouped()
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
return
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
print(
"torch_grouped_check: "
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
)
if args.profile:
with torch.profiler.profile( with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof: ) as prof:
run_naive_model(x) run_naive()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile( with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
record_shapes=True,
with_stack=False,
) as prof: ) as prof:
run_grouped_model(x) run_grouped()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
from __future__ import annotations
import sys
from pathlib import Path
import torch
# Third ancestor directory of this script (parents[2]); presumably the
# repository root — confirm against where this file lives.
ROOT = Path(__file__).resolve().parents[2]
# Extend the import path before the transformers/axolotl imports below run.
# NOTE(review): looks like this targets a local transformers checkout at
# <ROOT>/transformers/src and the project package at <ROOT>/src — confirm.
# Both entries are appended unconditionally (no existence or duplicate check),
# so anything already importable earlier on sys.path still takes precedence.
sys.path.extend(
[
str(ROOT / "transformers" / "src"),
str(ROOT / "src"),
]
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
def main() -> None:
    """Build one Qwen2 MoE block and print what ``_iter_expert_impls`` yields.

    Prints the number of expert implementations returned and, for the first
    eight, the class name, whether ``gate_proj`` / ``up_proj`` attributes are
    present, and the gate/up/down weight shapes when ``gate_proj`` exists.
    """
    cfg = Qwen2MoeConfig(
        hidden_size=4096,
        moe_intermediate_size=14336,
        shared_expert_intermediate_size=14336,
        num_experts=32,
        num_experts_per_tok=4,
    )
    # Only module types and weight shapes are inspected, so fall back to CPU
    # instead of crashing on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    block = Qwen2MoeSparseMoeBlock(cfg).to(device, dtype=torch.bfloat16)

    experts = block.experts
    # NOTE(review): presumably the grouped-mm code reads this back-reference to
    # reach the owning block — confirm in axolotl.kernels.moe.torch_grouped.
    experts._ax_parent_block = block

    impls = _iter_expert_impls(experts)
    print(f"impl count: {len(impls)}")
    for idx, impl in enumerate(impls[:8]):
        has_gate = hasattr(impl, "gate_proj")
        has_up = hasattr(impl, "up_proj")
        print(
            f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
        )
        if has_gate:
            print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
            print(f" up shape {tuple(impl.up_proj.weight.shape)}")
            print(f" down shape {tuple(impl.down_proj.weight.shape)}")


if __name__ == "__main__":
    main()