diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 0f2ae7d12..777a88024 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -109,9 +109,6 @@ def main(): ) p.add_argument("--iters", type=int, default=50) p.add_argument("--warmup", type=int, default=10) - p.add_argument( - "--check", action="store_true", help="Check numerical equivalence (outputs)" - ) args = p.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" @@ -166,30 +163,56 @@ def main(): t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup ) tflops = flops_total / ((t_ms / 1000.0) * 1e12) + speedup = t_naive / t_ms print( - f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s" + f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" ) - if args.check: - with torch.no_grad(): - y_ref = forward_naive(x, gate, experts, args.top_k) - y_fast = y - # align dtypes for error metrics - y_ref32 = y_ref.float() - y_fast32 = y_fast.float() - diff = (y_ref32 - y_fast32).abs() - max_abs = diff.max().item() - mean_abs = diff.mean().item() - rel_l2 = ( - (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item() - ) - print( - f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" - ) + with torch.no_grad(): + y_ref = forward_naive(x, gate, experts, args.top_k) + y_fast = y + y_ref32 = y_ref.float() + y_fast32 = y_fast.float() + diff = (y_ref32 - y_fast32).abs() + max_abs = diff.max().item() + mean_abs = diff.mean().item() + rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item() + print( + f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" + ) else: print("hf_triton\tN/A (kernels hub not available)") - # torch_grouped placeholder — not yet implemented - print("torch_grouped\tN/A (pending implementation)") + # torch_grouped backend (PyTorch 2.8+) + try: + from axolotl.kernels.moe import torch_grouped as tg + except Exception: + tg = None + if tg is not None and tg.available(): + + def forward_tg(a, g, ex, topk): + y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk) + return y + + y_tg = forward_tg(x, gate, experts, args.top_k) + if y_tg is not None: + t_ms = bench( + forward_tg, + x, + gate, + experts, + args.top_k, + iters=args.iters, + warmup=args.warmup, + ) + tflops = flops_total / ((t_ms / 1000.0) * 1e12) + speedup = t_naive / t_ms + print( + f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" + ) + else: + print("torch_grouped\tN/A (op not callable)") + else: + print("torch_grouped\tN/A (unavailable)") if __name__ == "__main__": diff --git a/scripts/probe_torch_grouped_ops.py b/scripts/probe_torch_grouped_ops.py new file mode 100644 index 000000000..2eac1590d --- /dev/null +++ b/scripts/probe_torch_grouped_ops.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +""" +Probe PyTorch for grouped GEMM operator names and namespaces. +Run: python scripts/probe_torch_grouped_ops.py +""" + +import sys + + +def main(): + try: + import torch + except Exception as e: + print("Failed to import torch:", e) + sys.exit(1) + + print("torch version:", torch.__version__) + namespaces = [n for n in dir(torch.ops) if not n.startswith("_")] + print("ops namespaces:", namespaces) + + found_any = False + for ns in namespaces: + obj = getattr(torch.ops, ns, None) + ops = [] + if obj is not None: + try: + ops = dir(obj) + except Exception as e: + print(f"warning: failed to list ops for namespace {ns}: {e}") + cands = [ + o + for o in ops + if ("group" in o.lower()) + or ("mm_grouped" in o.lower()) + or ("matmul_grouped" in o.lower()) + or ("grouped" in o.lower()) + ] + if cands: + found_any = True + print(f"namespace {ns} candidates:", cands) + + if not found_any: + print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.") + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index ba4ed2845..d8689fba0 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -1,16 +1,141 @@ """ -Placeholder for PyTorch 2.8+ grouped GEMM MoE path. -Currently probes availability; full integration to be implemented. +PyTorch 2.8+ grouped GEMM MoE path (cuBLASLt-backed). +This is a cautious first pass that probes available ops and only runs when supported. """ from __future__ import annotations +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F + def available() -> bool: try: - import torch # noqa: F401 - ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) - return ver >= (2, 8) + if ver < (2, 8): + return False + # Check for aten grouped mm ops + return hasattr(torch.ops, "aten") and ( + hasattr(torch.ops.aten, "_grouped_mm") + or hasattr(torch.ops.aten, "_scaled_grouped_mm") + ) except Exception: return False + + +def _call_grouped_mm( + As: List[torch.Tensor], Bs: List[torch.Tensor] +) -> Optional[List[torch.Tensor]]: + """ + Try calling the appropriate grouped mm op available in this torch build. + Returns list of outputs or None on failure. + """ + try: + if hasattr(torch.ops.aten, "_grouped_mm"): + return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined] + if hasattr(torch.ops.aten, "_scaled_grouped_mm"): + # signature likely (As, Bs, alpha, beta) + return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined] + except Exception: + return None + return None + + +def moe_ffn_forward_grouped( + hidden_states, gate_linear, experts_module, top_k: int +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Attempt a grouped GEMM fast path using PyTorch 2.8+. + If unavailable or fails, returns (None, None) so caller can fallback. + """ + try: + bsz, seqlen, hdim = hidden_states.shape + x = hidden_states.view(-1, hdim) + router_logits = gate_linear(x) + + # topk routing in torch (keep simple to avoid dependency cycles) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) + topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype) + + # Build per-expert input lists + flat_idx = topk_idx.view(-1) + x_rep = x.repeat_interleave(top_k, dim=0) + + # Cache stacked weights on experts + E = experts_module.num_experts + dev, dt = x.device, x.dtype + if ( + not hasattr(experts_module, "_stacked_w1") + or experts_module._stacked_w1.device != dev + or experts_module._stacked_w1.dtype != dt + ): + w1 = [experts_module[i].w1.weight.t() for i in range(E)] + w3 = [experts_module[i].w3.weight.t() for i in range(E)] + w2 = [experts_module[i].w2.weight.t() for i in range(E)] + experts_module._stacked_w1 = ( + torch.stack(w1, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w3 = ( + torch.stack(w3, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w2 = ( + torch.stack(w2, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w13 = torch.cat( + [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 + ).contiguous() + W13 = experts_module._stacked_w13 + W2 = experts_module._stacked_w2 + + # Grouped GEMM for up+gate + As: List[torch.Tensor] = [] + Bs: List[torch.Tensor] = [] + expert_slices = [] + for i in range(E): + sel = flat_idx == i + if sel.any(): + Xi = x_rep[sel] + As.append(Xi) + Bs.append(W13[i]) + expert_slices.append((i, sel)) + + if not As: + # no tokens routed — edge case + out = torch.zeros_like(x) + return out.view(bsz, seqlen, hdim), router_logits + + Y_list = _call_grouped_mm(As, Bs) + if Y_list is None: + return None, None + + # SwiGLU on each expert block and prepare for down projection + As2: List[torch.Tensor] = [] + Bs2: List[torch.Tensor] = [] + y_buf = torch.empty_like(x_rep) + # split Y into (I, I) + for (i, sel), Yi in zip(expert_slices, Y_list): + I2 = Yi.shape[-1] // 2 + Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] + As2.append(Yi_hidden) + Bs2.append(W2[i]) + + Y2_list = _call_grouped_mm(As2, Bs2) + if Y2_list is None: + return None, None + + # Write back, apply per-token weighting, and reduce over top_k + for (i, sel), Out_i in zip(expert_slices, Y2_list): + y_buf[sel] = Out_i + y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + return y.view(bsz, seqlen, hdim), router_logits + except Exception: + return None, None