From 43ada1278acd5cba6347d45041fd6fe3d524e877 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 15 Sep 2025 12:20:41 -0400 Subject: [PATCH] moe kernels init scaffold --- docs/moe_backends.md | 18 +++ scripts/bench_moe.py | 161 ++++++++++++++++++++ src/axolotl/kernels/moe/__init__.py | 3 + src/axolotl/kernels/moe/backends.py | 65 ++++++++ src/axolotl/kernels/moe/hf_triton.py | 104 +++++++++++++ src/axolotl/kernels/moe/torch_grouped.py | 16 ++ src/axolotl/monkeypatch/mixtral/__init__.py | 34 ++++- 7 files changed, 397 insertions(+), 4 deletions(-) create mode 100644 docs/moe_backends.md create mode 100644 scripts/bench_moe.py create mode 100644 src/axolotl/kernels/moe/__init__.py create mode 100644 src/axolotl/kernels/moe/backends.py create mode 100644 src/axolotl/kernels/moe/hf_triton.py create mode 100644 src/axolotl/kernels/moe/torch_grouped.py diff --git a/docs/moe_backends.md b/docs/moe_backends.md new file mode 100644 index 000000000..150fa0eb0 --- /dev/null +++ b/docs/moe_backends.md @@ -0,0 +1,18 @@ +MoE Backends in Axolotl + +Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via an environment variable: + +- AXOLOTL_MOE_BACKEND=auto|hf_triton|torch_grouped|naive + +Behavior +- auto (default): prefers PyTorch 2.8+ grouped GEMM, then the Hugging Face kernels hub, otherwise naive. +- hf_triton: uses the Hugging Face kernels hub (kernels-community/triton_kernels) when available. +- torch_grouped: targets PyTorch 2.8+ grouped GEMM. +- naive: keeps the reference per-expert loop. + +Notes +- The current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses the kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated. +- No changes to training scripts are required: Axolotl wraps the Transformers Trainer, and backend selection happens inside the model forward pass; see the programmatic sketch below.
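+
+Programmatic selection
+A minimal sketch (not needed for training runs) that uses only the helpers added in this change:
+
+    from axolotl.kernels.moe import MOEBackend, get_moe_backend_name
+
+    backend = get_moe_backend_name()  # honors AXOLOTL_MOE_BACKEND, otherwise auto-detects
+    if backend is MOEBackend.NAIVE:
+        print("no fused MoE backend available; using the reference per-expert loop")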
+ +Example +AXOLOTL_MOE_BACKEND=hf_triton accelerate launch -m axolotl.cli.train path/to/config.yaml diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py new file mode 100644 index 000000000..644d22320 --- /dev/null +++ b/scripts/bench_moe.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +import argparse +import os +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwiGLUMlp(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False) + self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False) + self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = F.silu + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + +class Experts(nn.Module): + def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int): + super().__init__() + self.layers = nn.ModuleList( + SwiGLUMlp(hidden_size, intermediate_size) for _ in range(num_experts) + ) + self.num_experts = num_experts + + def __getitem__(self, idx): + return self.layers[idx] + + +def forward_naive( + hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int +): + bsz, seqlen, hdim = hidden_states.shape + x = hidden_states.view(-1, hdim) + router_logits = gate(x) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) + topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype) + x_rep = x.repeat_interleave(top_k, dim=0) + y = torch.empty_like(x_rep) + flat_idx = topk_idx.view(-1) + for i in range(experts.num_experts): + sel = flat_idx == i + if sel.any(): + y[sel] = experts[i](x_rep[sel]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + return y.view(bsz, seqlen, hdim) + + +def forward_hf_triton( + hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int +): + try: + from axolotl.kernels.moe import hf_triton as _hf + except Exception: + return None + try: + y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k) + return y + except Exception: + return None + + +def bench(fn, *args, iters=50, warmup=10, sync=True): + # warmup + for _ in range(warmup): + out = fn(*args) + if sync and torch.cuda.is_available(): + torch.cuda.synchronize() + # measure + times = [] + for _ in range(iters): + if sync and torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.perf_counter() + out = fn(*args) + if sync and torch.cuda.is_available(): + torch.cuda.synchronize() + dt = (time.perf_counter() - t0) * 1000.0 + times.append(dt) + return sum(times) / len(times) + + +def main(): + p = argparse.ArgumentParser(description="MoE microbenchmark") + p.add_argument("--bsz", type=int, default=8) + p.add_argument("--seq", type=int, default=1024) + p.add_argument("--hidden", type=int, default=4096) + p.add_argument("--inter", type=int, default=14336) + p.add_argument("--experts", type=int, default=8) + p.add_argument("--top_k", type=int, default=2) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"] + ) + p.add_argument("--iters", type=int, default=50) + p.add_argument("--warmup", type=int, default=10) + args = p.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + }[args.dtype] 
+ + torch.manual_seed(0) + if device == "cuda": + torch.cuda.manual_seed(0) + + # Model + experts = Experts(args.experts, args.hidden, args.inter).to( + device=device, dtype=dtype + ) + gate = nn.Linear(args.hidden, args.experts, bias=False).to( + device=device, dtype=dtype + ) + + # data + x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype) + + # Report config + tokens = args.bsz * args.seq + print( + f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} experts={args.experts} top_k={args.top_k}" + ) + + # Naive baseline + t_naive = bench( + forward_naive, + x, + gate, + experts, + args.top_k, + iters=args.iters, + warmup=args.warmup, + ) + print(f"naive {t_naive:.2f} ms {tokens / (t_naive / 1000):.1f} tok/s") + + # HF Triton (routing + stub compute for now) + os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton") + t_hf = forward_hf_triton + y = t_hf(x, gate, experts, args.top_k) + if y is not None: + t_ms = bench( + t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup + ) + print(f"hf_triton {t_ms:.2f} ms {tokens / (t_ms / 1000):.1f} tok/s") + else: + print("hf_triton N/A (kernels hub not available)") + + # torch_grouped placeholder — not yet implemented + print("torch_grouped N/A (pending implementation)") + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/kernels/moe/__init__.py b/src/axolotl/kernels/moe/__init__.py new file mode 100644 index 000000000..e1d16d8fa --- /dev/null +++ b/src/axolotl/kernels/moe/__init__.py @@ -0,0 +1,3 @@ +from .backends import MOEBackend, get_moe_backend_name + +__all__ = ["get_moe_backend_name", "MOEBackend"] diff --git a/src/axolotl/kernels/moe/backends.py b/src/axolotl/kernels/moe/backends.py new file mode 100644 index 000000000..c625718a2 --- /dev/null +++ b/src/axolotl/kernels/moe/backends.py @@ -0,0 +1,65 @@ +import os +import warnings +from enum import Enum + + +class MOEBackend(str, Enum): + AUTO = "auto" + HF_TRITON = "hf_triton" + TORCH_GROUPED = "torch_grouped" + NAIVE = "naive" + + +def _probe_torch_grouped() -> bool: + try: + import torch # noqa: F401 + + # Prefer a simple version check; exact APIs may vary across 2.8+. + ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) + return ver >= (2, 8) + except Exception: + return False + + +def _probe_hf_triton() -> bool: + try: + # The hub loads kernels lazily; this import is a light probe. 
+ import importlib + + importlib.import_module("kernels") + return True + except Exception: + return False + + +def get_moe_backend_name(preferred: str | None = None) -> MOEBackend: + """ + Resolve the desired MoE backend using, in order of precedence: + - explicit preferred argument + - environment variable AXOLOTL_MOE_BACKEND + - auto detection + """ + choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower() + try: + selected = MOEBackend(choice) + except ValueError: + warnings.warn(f"Unknown moe backend '{choice}', falling back to auto") + selected = MOEBackend.AUTO + + if selected == MOEBackend.AUTO: + if _probe_torch_grouped(): + return MOEBackend.TORCH_GROUPED + if _probe_hf_triton(): + return MOEBackend.HF_TRITON + return MOEBackend.NAIVE + if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped(): + warnings.warn( + "torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive" + ) + return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE + if selected == MOEBackend.HF_TRITON and not _probe_hf_triton(): + warnings.warn( + "hf_triton requested but kernels hub not available; falling back to torch_grouped/naive" + ) + return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE + return selected diff --git a/src/axolotl/kernels/moe/hf_triton.py b/src/axolotl/kernels/moe/hf_triton.py new file mode 100644 index 000000000..b0fbd29a5 --- /dev/null +++ b/src/axolotl/kernels/moe/hf_triton.py @@ -0,0 +1,104 @@ +""" +Adapter for Hugging Face kernels hub (kernels-community/triton_kernels). +This file provides light probes and placeholders for future integration. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Tuple + + +@dataclass +class HFTritonHandles: + routing: Any + matmul_ogs: Any + swiglu: Any + + +def available() -> bool: + try: + import kernels # noqa: F401 + + return True + except Exception: + return False + + +def load() -> Optional[HFTritonHandles]: + try: + from kernels import get_kernel + + tk = get_kernel("kernels-community/triton_kernels") + return HFTritonHandles( + routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu + ) + except Exception: + return None + + +def route_topk(logits, top_k: int): + handles = load() + if handles is None: + return None + return handles.routing.routing_torch(logits, n_expts_act=top_k) + + +def swiglu(x, alpha, limit=1.0, routing_data=None): + handles = load() + if handles is None: + return None + pc = handles.swiglu.PrecisionConfig(limit=limit) + return handles.swiglu.swiglu(x, alpha, pc, routing_data) + + +def moe_ffn_forward_stub( + hidden_states, gate_linear, experts_module, top_k: int +) -> Tuple[object, object]: + """ + Temporary stub that uses kernels hub routing, but falls back to per-expert compute. + Returns (final_hidden_states, router_logits). 
+ """ + import torch + import torch.nn.functional as F + + bsz, seqlen, hdim = hidden_states.shape + flat = hidden_states.view(-1, hdim) + router_logits = gate_linear(flat) + # use hub routing if available; otherwise fall back to softmax + top-k + routed = None + if available(): + try: + routed = route_topk(router_logits, top_k) + except Exception: + routed = None + if routed is None: + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) + topk_weight /= topk_weight.sum(dim=-1, keepdim=True) + topk_weight = topk_weight.to(flat.dtype) + x_rep = flat.repeat_interleave(top_k, dim=0) + y = torch.empty_like(x_rep) + flat_idx = topk_idx.view(-1) + for i in range(experts_module.num_experts): + expert = experts_module[i] + y[flat_idx == i] = expert(x_rep[flat_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + return y.reshape(bsz, seqlen, hdim), router_logits + # If routed via hub, still fall back to per-expert compute until the grouped GEMM path is wired. + ex_routing_data, gather_idx, scatter_idx = routed + # The routing data is unpacked only to validate the hub call; gather_idx/scatter_idx stay unused + # until the packed matmul_ogs + swiglu compute lands. For now, recompute the naive top-k routing + # and per-expert loop so the result matches the reference path exactly (no speedup, but it + # confirms that hub routing runs). + routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float) + topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) + topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype) + x_rep = flat.repeat_interleave(top_k, dim=0) + y = torch.empty_like(x_rep) + flat_idx = topk_idx.view(-1) + for i in range(experts_module.num_experts): + expert = experts_module[i] + y[flat_idx == i] = expert(x_rep[flat_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + return y.reshape(bsz, seqlen, hdim), router_logits diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py new file mode 100644 index 000000000..ba4ed2845 --- /dev/null +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -0,0 +1,16 @@ +""" +Placeholder for the PyTorch 2.8+ grouped GEMM MoE path. +Currently probes availability; full integration to be implemented.
+""" + +from __future__ import annotations + + +def available() -> bool: + try: + import torch # noqa: F401 + + ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) + return ver >= (2, 8) + except Exception: + return False diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py index b353b12cf..988a1c6f9 100644 --- a/src/axolotl/monkeypatch/mixtral/__init__.py +++ b/src/axolotl/monkeypatch/mixtral/__init__.py @@ -6,8 +6,13 @@ import torch def patch_mixtral_moe_forward_zero3() -> None: + import warnings + import torch.nn.functional as F + from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton + from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name + def mlp_forward(self, hidden_states): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( hidden_states @@ -21,21 +26,42 @@ def patch_mixtral_moe_forward_zero3() -> None: hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) + backend = get_moe_backend_name() + if backend == MOEBackend.HF_TRITON and _hf_triton.available(): + # Stub path: use kernels hub routing and fallback per-expert compute + try: + final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub( + hidden_states.view(batch_size, sequence_length, hidden_dim), + self.gate, + self.experts, + self.top_k, + ) + return final_hidden_states, router_logits + except Exception as e: + warnings.warn(f"hf_triton backend failed, falling back to naive: {e}") + elif ( + backend == MOEBackend.TORCH_GROUPED + and not _moe_backends._probe_torch_grouped() + ): + warnings.warn( + "torch_grouped selected but not available; falling back to naive" + ) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) topk_weight, topk_idx = torch.topk( routing_weights, self.top_k, dim=-1, sorted=False ) topk_weight /= topk_weight.sum(dim=-1, keepdim=True) - # we cast back to the input dtype topk_weight = topk_weight.to(hidden_states.dtype) - hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0) - y = torch.empty_like(hidden_states) + hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0) + y = torch.empty_like(hidden_states_rep) flat_topk_idx = topk_idx.view(-1) for i in range(self.num_experts): expert = self.experts[i] - y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + sel = flat_topk_idx == i + if sel.any(): + y[sel] = expert(hidden_states_rep[sel]) y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits