diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index 3cb434edb..78fdff374 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -1,121 +1,83 @@
 #!/usr/bin/env python
+"""Benchmark a Hugging Face Qwen2 MoE block with and without grouped_mm."""
+
+from __future__ import annotations
+
 import argparse
 import sys
 import time
 from pathlib import Path

 import torch
-import torch.nn as nn
-import torch.nn.functional as F

 try:
     from axolotl.kernels.moe import torch_grouped as tg
-except Exception:  # pragma: no cover - fallback when torch_grouped unavailable
+except Exception:  # pragma: no cover - torch_grouped is optional
     tg = None


-class SwiGLUMlp(nn.Module):
-    def __init__(self, hidden_size: int, intermediate_size: int):
-        super().__init__()
-        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
-        self.act_fn = F.silu
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))
-
-
-class Experts(nn.Module):
-    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
-        super().__init__()
-        self.layers = nn.ModuleList(
-            SwiGLUMlp(hidden_size, intermediate_size) for _ in range(num_experts)
-        )
-        self.num_experts = num_experts
-
-    def __getitem__(self, idx):
-        return self.layers[idx]
-
-
-def forward_naive(
-    hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
-):
-    bsz, seqlen, hdim = hidden_states.shape
-    x = hidden_states.view(-1, hdim)
-    router_logits = gate(x)
-    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
-    topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
-    x_rep = x.repeat_interleave(top_k, dim=0)
-    y = torch.empty_like(x_rep)
-    flat_idx = topk_idx.view(-1)
-    for i in range(experts.num_experts):
-        sel = flat_idx == i
-        if sel.any():
-            y[sel] = experts[i](x_rep[sel])
-    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-    return y.view(bsz, seqlen, hdim)
-
-
-def bench(fn, iters=50, warmup=10, sync=True):
-    # warmup
+def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
+    use_cuda = sync and torch.cuda.is_available()
     for _ in range(warmup):
-        fn()
-    if sync and torch.cuda.is_available():
+        run()
+    if use_cuda:
         torch.cuda.synchronize()
-    # measure
     times = []
     for _ in range(iters):
-        if sync and torch.cuda.is_available():
+        if use_cuda:
             torch.cuda.synchronize()
-        t0 = time.perf_counter()
-        fn()
-        if sync and torch.cuda.is_available():
+        start = time.perf_counter()
+        run()
+        if use_cuda:
             torch.cuda.synchronize()
-        dt = (time.perf_counter() - t0) * 1000.0
-        times.append(dt)
+        times.append((time.perf_counter() - start) * 1000.0)
     return sum(times) / len(times)


 def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
-    """Estimate GEMM FLOPs for a SwiGLU MoE MLP.
-
-    Two up projections (w1,w3) + one down (w2), each token processed by top_k experts.
-    FLOPs ≈ 6 * (tokens * top_k) * hidden * inter (2*m*k*n per GEMM).
-    """
-    m_rep = tokens * top_k
-    return 6.0 * m_rep * hidden * inter
+    # Three GEMMs per token-expert pair (gate, up, down) at 2*m*k*n FLOPs each,
+    # so FLOPs ≈ 6 * (tokens * top_k) * hidden * inter.
+    return 6.0 * tokens * top_k * hidden * inter
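Reviewer note: as a sanity check on the constant, evaluating the estimate at this script's new defaults (bsz=8, seq=1024, hidden=4096, inter=14336, top_k=4) gives roughly 11.5 TFLOPs per forward pass; the snippet below is plain arithmetic on the formula, not a measured value.

    tokens = 8 * 1024                        # bsz * seq = 8192
    flops = 6.0 * tokens * 4 * 4096 * 14336  # estimate_moe_flops(tokens, 4096, 14336, 4)
    print(f"{flops / 1e12:.2f} TFLOPs")      # 11.54 TFLOPs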

-def main():
-    p = argparse.ArgumentParser(description="MoE microbenchmark")
+def load_hf_block(
+    hidden: int,
+    inter: int,
+    experts: int,
+    top_k: int,
+    *,
+    device: torch.device,
+    dtype: torch.dtype,
+):
+    # Prefer a sibling `transformers` checkout (two levels above scripts/) over
+    # the installed package so local modeling changes are picked up.
+    project_root = Path(__file__).resolve().parents[2]
+    transformers_src = project_root / "transformers" / "src"
+    if transformers_src.exists() and str(transformers_src) not in sys.path:
+        sys.path.append(str(transformers_src))
+
+    from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
+    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+
+    cfg = Qwen2MoeConfig(
+        hidden_size=hidden,
+        moe_intermediate_size=inter,
+        shared_expert_intermediate_size=inter,
+        num_experts=experts,
+        num_experts_per_tok=top_k,
+        norm_topk_prob=True,
+        qkv_bias=True,
+    )
+
+    # Two identical blocks sharing one set of weights: one runs the stock HF
+    # forward ("naive"), the other the grouped path.
+    block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+    block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
+    block_grouped.load_state_dict(block.state_dict())
+    return block, block_grouped
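The `parents[2]` lookup above is easy to misread: it resolves to the directory *containing* this repo, so the `transformers` checkout is expected to sit next to the repo, not inside it. A minimal illustration, with hypothetical paths:

    from pathlib import Path

    here = Path("/workspace/axolotl/scripts/bench_moe.py")  # hypothetical location of this file
    print(here.parents[2])                                  # /workspace
    # -> /workspace/transformers/src is the path probed above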
+
+
+def main() -> None:
+    p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
     p.add_argument("--bsz", type=int, default=8)
     p.add_argument("--seq", type=int, default=1024)
     p.add_argument("--hidden", type=int, default=4096)
     p.add_argument("--inter", type=int, default=14336)
-    p.add_argument("--experts", type=int, default=8)
-    p.add_argument("--top_k", type=int, default=2)
-    p.add_argument(
-        "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
-    )
+    p.add_argument("--experts", type=int, default=32)
+    p.add_argument("--top_k", type=int, default=4)
+    p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
     p.add_argument("--iters", type=int, default=50)
     p.add_argument("--warmup", type=int, default=10)
-    p.add_argument(
-        "--hf-block",
-        type=str,
-        default="none",
-        choices=["none", "qwen2_moe"],
-        help="Use a Hugging Face MoE block for benchmarking instead of the toy SwiGLU layer.",
-    )
     p.add_argument(
         "--profile",
         action="store_true",
         help="Capture CUDA profiler tables for naive and grouped runs.",
     )
     args = p.parse_args()

-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     dtype = {
         "bf16": torch.bfloat16,
         "fp16": torch.float16,
@@ -123,138 +85,80 @@ def main():
     }[args.dtype]

     torch.manual_seed(0)
-    if device == "cuda":
+    if device.type == "cuda":
         torch.cuda.manual_seed(0)

-    # data
-    x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
-
-    tokens = args.bsz * args.seq
-
-    use_hf = args.hf_block != "none"
-
-    if use_hf:
-        project_root = Path(__file__).resolve().parents[2]
-        transformers_src = project_root / "transformers" / "src"
-        if transformers_src.exists() and str(transformers_src) not in sys.path:
-            sys.path.append(str(transformers_src))
-
-        if args.hf_block == "qwen2_moe":
-            from transformers.models.qwen2_moe.configuration_qwen2_moe import (
-                Qwen2MoeConfig,
-            )
-            from transformers.models.qwen2_moe.modeling_qwen2_moe import (
-                Qwen2MoeSparseMoeBlock,
-            )
-
-            cfg = Qwen2MoeConfig(
-                hidden_size=args.hidden,
-                moe_intermediate_size=args.inter,
-                shared_expert_intermediate_size=args.inter,
-                num_experts=args.experts,
-                num_experts_per_tok=args.top_k,
-                norm_topk_prob=True,
-                qkv_bias=True,
-            )
-
-            block_naive = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
-            block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
-            block_grouped.load_state_dict(block_naive.state_dict())
-
-            def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
-                out, _ = block_naive(inp)
-                return out
-
-            def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
-                if tg is None or not tg.available():
-                    return torch.empty(0)
-                block_grouped.experts._ax_parent_block = block_grouped
-                y, _ = tg.moe_ffn_forward_grouped(
-                    inp, block_grouped.gate, block_grouped.experts, block_grouped.top_k
-                )
-                return y if y is not None else torch.empty(0)
-
-            flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
-            print(
-                f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
-                f"experts={args.experts} top_k={args.top_k} hf_block={args.hf_block}"
-            )
-
-    else:
-        experts = Experts(args.experts, args.hidden, args.inter).to(
-            device=device, dtype=dtype
-        )
-        gate = nn.Linear(args.hidden, args.experts, bias=False).to(
-            device=device, dtype=dtype
-        )
-
-        def run_naive_model(inp: torch.Tensor) -> torch.Tensor:
-            return forward_naive(inp, gate, experts, args.top_k)
-
-        def run_grouped_model(inp: torch.Tensor) -> torch.Tensor:
-            if tg is None or not tg.available():
-                return torch.empty(0)
-            y, _ = tg.moe_ffn_forward_grouped(inp, gate, experts, args.top_k)
-            return y if y is not None else torch.empty(0)
-
-        flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
-        print(
-            f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
-            f"experts={args.experts} top_k={args.top_k}"
-        )
-
-    # Benchmark naive
-    t_naive = bench(lambda: run_naive_model(x), iters=args.iters, warmup=args.warmup)
-    tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
-    print(
-        f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
+    block_naive, block_grouped = load_hf_block(
+        args.hidden,
+        args.inter,
+        args.experts,
+        args.top_k,
+        device=device,
+        dtype=dtype,
     )

+    tokens = args.bsz * args.seq
+    flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+    print(
+        f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
+        f"experts={args.experts} top_k={args.top_k}"
+    )
+
+    x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
+
+    def run_naive():
+        y, _ = block_naive(x)
+        return y
+
+    # Attach the parent block once up front so the grouped helpers can find the
+    # routing config; setting it inside run_grouped would pollute the timed region.
+    if tg is not None:
+        block_grouped.experts._ax_parent_block = block_grouped
+
+    def run_grouped():
+        if tg is None or not tg.available():
+            return torch.empty(0)
+        y, _ = tg.moe_ffn_forward_grouped(
+            x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
+        )
+        return y if y is not None else torch.empty(0)
+
+    t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
+    tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
+    print(
+        f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
+        f"{tflops_naive:.2f} TFLOP/s"
+    )
+
     with torch.no_grad():
-        y_ref = run_naive_model(x)
+        y_ref = run_naive()

-    # Benchmark grouped
-    if tg is not None and tg.available():
-        y_grouped = run_grouped_model(x)
-        if y_grouped.numel() == 0:
-            print("torch_grouped\tN/A (op not callable)")
-        else:
-            t_grouped = bench(
-                lambda: run_grouped_model(x),
-                iters=args.iters,
-                warmup=args.warmup,
-            )
-            tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
-            speedup = t_naive / t_grouped
-            print(
-                f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000):.1f} tok/s\t"
-                f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
-            )
-            diff = (y_ref.float() - y_grouped.float()).abs()
-            max_abs = diff.max().item()
-            mean_abs = diff.mean().item()
-            rel_l2 = (
-                (diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item()
-            )
-            print(
-                f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-            )
-    else:
+    if tg is None or not tg.available():
         print("torch_grouped\tN/A (unavailable)")
+        return

-    if args.profile and tg is not None and tg.available():
+    y_grouped = run_grouped()
+    if y_grouped.numel() == 0:
+        print("torch_grouped\tN/A (op not callable)")
+        return
+
+    t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
+    tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
+    speedup = t_naive / t_grouped
+    print(
+        f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
+        f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
+    )
+
+    diff = (y_ref.float() - y_grouped.float()).abs()
+    print(
+        "torch_grouped_check: "
+        f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
+        f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
+    )
+
+    if args.profile:
         with torch.profiler.profile(
             activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
         ) as prof:
-            run_naive_model(x)
+            run_naive()
         print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

         with torch.profiler.profile(
-            activities=[torch.profiler.ProfilerActivity.CUDA],
-            record_shapes=True,
-            with_stack=False,
+            activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
         ) as prof:
-            run_grouped_model(x)
+            run_grouped()
         print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
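For reviewers unfamiliar with the grouped path: `moe_ffn_forward_grouped` itself is not part of this diff, but the token dispatch that grouped-GEMM MoE kernels generally rely on can be sketched with plain PyTorch ops. The sketch below is illustrative only, not the kernel's actual implementation; `grouped_dispatch_sketch` and all its locals are names invented for this example.

    import torch

    def grouped_dispatch_sketch(x, router_logits, top_k, num_experts):
        """Sort token-expert pairs by expert so each expert sees one contiguous slice."""
        weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
        topk_w, topk_idx = torch.topk(weights, top_k, dim=-1)
        flat_idx = topk_idx.reshape(-1)                   # (tokens * top_k,)
        order = torch.argsort(flat_idx)                   # group token-expert pairs by expert id
        x_rep = x.repeat_interleave(top_k, dim=0)[order]  # rows for expert e are now contiguous
        counts = torch.bincount(flat_idx, minlength=num_experts)
        offsets = torch.cumsum(counts, dim=0)             # group boundaries for the grouped GEMM
        return x_rep, order, topk_w, offsets

A grouped GEMM then multiplies each contiguous slice `x_rep[offsets[e-1]:offsets[e]]` by expert `e`'s stacked weights in a single kernel launch instead of one launch per expert, and `order` is used to unsort the outputs back to token order before the routing-weighted sum.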
diff --git a/scripts/debug_qwen2_experts.py b/scripts/debug_qwen2_experts.py
new file mode 100644
index 000000000..9b9057689
--- /dev/null
+++ b/scripts/debug_qwen2_experts.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+import torch
+
+# Assumes sibling checkouts: <workspace>/transformers/src next to this repo's src/.
+ROOT = Path(__file__).resolve().parents[2]
+sys.path.extend(
+    [
+        str(ROOT / "transformers" / "src"),
+        str(ROOT / "src"),
+    ]
+)
+
+from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
+from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+
+from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
+
+
+def main() -> None:
+    cfg = Qwen2MoeConfig(
+        hidden_size=4096,
+        moe_intermediate_size=14336,
+        shared_expert_intermediate_size=14336,
+        num_experts=32,
+        num_experts_per_tok=4,
+    )
+
+    # Shape inspection does not need a GPU; fall back to CPU when CUDA is absent.
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    block = Qwen2MoeSparseMoeBlock(cfg).to(device, dtype=torch.bfloat16)
+    experts = block.experts
+    # _ax_parent_block lets the torch_grouped helpers reach the owning block
+    # from the experts module, mirroring what bench_moe.py does.
+    setattr(experts, "_ax_parent_block", block)
+
+    impls = _iter_expert_impls(experts)
+    print(f"impl count: {len(impls)}")
+    for idx, impl in enumerate(impls[:8]):
+        has_gate = hasattr(impl, "gate_proj")
+        has_up = hasattr(impl, "up_proj")
+        print(
+            f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
+        )
+        if has_gate and has_up:
+            print(f"  gate shape {tuple(impl.gate_proj.weight.shape)}")
+            print(f"  up shape {tuple(impl.up_proj.weight.shape)}")
+        if hasattr(impl, "down_proj"):
+            print(f"  down shape {tuple(impl.down_proj.weight.shape)}")
+
+
+if __name__ == "__main__":
+    main()
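Usage note: assuming the sibling `transformers` checkout described above, a typical smoke test of this change is the pair of invocations below (hypothetical command lines, built only from the flags that bench_moe.py's argparse actually defines).

    python scripts/debug_qwen2_experts.py
    python scripts/bench_moe.py --bsz 8 --seq 1024 --experts 32 --top_k 4 --dtype bf16 --profile

The benchmark prints one `naive` and one `torch_grouped` line (ms, tok/s, TFLOP/s, and speedup for the grouped path) plus a `torch_grouped_check` line reporting max_abs/mean_abs/rel_l2 of the grouped output against the stock HF forward.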