moe kernels init scaffold

This commit is contained in:
Dan Saunders
2025-09-15 12:20:41 -04:00
parent 4065bc14c6
commit 43ada1278a
7 changed files with 397 additions and 4 deletions

View File

@@ -0,0 +1,3 @@
from .backends import MOEBackend, get_moe_backend_name
__all__ = ["get_moe_backend_name", "MOEBackend"]

View File

@@ -0,0 +1,65 @@
import os
import warnings
from enum import Enum


class MOEBackend(str, Enum):
AUTO = "auto"
HF_TRITON = "hf_triton"
TORCH_GROUPED = "torch_grouped"
NAIVE = "naive"


def _probe_torch_grouped() -> bool:
try:
import torch # noqa: F401
# Prefer a simple version check; exact APIs may vary across 2.8+.
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
return ver >= (2, 8)
except Exception:
return False


def _probe_hf_triton() -> bool:
try:
# The hub loads kernels lazily; this import is a light probe.
import importlib
importlib.import_module("kernels")
return True
except Exception:
return False


def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
- explicit preferred argument
- environment variable AXOLOTL_MOE_BACKEND
- auto detection
"""
choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower()
try:
selected = MOEBackend(choice)
except ValueError:
warnings.warn(f"Unknown moe backend '{choice}', falling back to auto")
selected = MOEBackend.AUTO
if selected == MOEBackend.AUTO:
if _probe_torch_grouped():
return MOEBackend.TORCH_GROUPED
if _probe_hf_triton():
return MOEBackend.HF_TRITON
return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive"
)
return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE
if selected == MOEBackend.HF_TRITON and not _probe_hf_triton():
warnings.warn(
"hf_triton requested but kernels hub not available; falling back to torch_grouped/naive"
)
return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE
return selected
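
A minimal usage sketch of the resolver above (not part of the diff), showing the three precedence levels; the resolved value depends on the local torch / kernels installs:

import os
from axolotl.kernels.moe import MOEBackend, get_moe_backend_name

# 1. An explicit argument wins over everything else.
assert get_moe_backend_name("naive") == MOEBackend.NAIVE

# 2. Otherwise the AXOLOTL_MOE_BACKEND environment variable is consulted.
os.environ["AXOLOTL_MOE_BACKEND"] = "hf_triton"
print(get_moe_backend_name())  # hf_triton if the kernels hub imports, else a probed fallback

# 3. With neither set, auto-detection tries torch>=2.8, then the kernels hub, then naive.
del os.environ["AXOLOTL_MOE_BACKEND"]
print(get_moe_backend_name())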

View File

@@ -0,0 +1,104 @@
"""
Adapter for Hugging Face kernels hub (kernels-community/triton_kernels).
This module provides light availability probes, thin wrappers, and a temporary forward stub ahead of full integration.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class HFTritonHandles:
routing: Any
matmul_ogs: Any
swiglu: Any


def available() -> bool:
try:
import kernels # noqa: F401
return True
except Exception:
return False


def load() -> Optional[HFTritonHandles]:
try:
from kernels import get_kernel
tk = get_kernel("kernels-community/triton_kernels")
return HFTritonHandles(
routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu
)
except Exception:
return None


def route_topk(logits, top_k: int):
handles = load()
if handles is None:
return None
return handles.routing.routing_torch(logits, n_expts_act=top_k)


def swiglu(x, alpha, limit=1.0, routing_data=None):
handles = load()
if handles is None:
return None
pc = handles.swiglu.PrecisionConfig(limit=limit)
return handles.swiglu.swiglu(x, alpha, pc, routing_data)
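
A short sketch (not part of the diff) of the graceful-degradation contract these wrappers expose with the helpers above in scope: they return None whenever the kernels hub cannot be loaded, so callers can drop to eager PyTorch. Shapes are illustrative, and the three-tuple unpacking mirrors moe_ffn_forward_stub below:

import torch

logits = torch.randn(8, 4)  # (tokens, n_experts) router logits
routed = route_topk(logits, top_k=2)  # None when the kernels hub is unavailable
if routed is None:
    # Same eager fallback the stub below uses: plain softmax + top-k.
    weights = torch.softmax(logits, dim=-1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(weights, 2, dim=-1, sorted=False)
else:
    routing_data, gather_idx, scatter_idx = routed  # layout later consumed by matmul_ogs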


def moe_ffn_forward_stub(
hidden_states, gate_linear, experts_module, top_k: int
) -> Tuple[object, object]:
"""
    Temporary stub: uses kernels hub routing when available, but still computes experts
    one at a time until the grouped GEMM path is wired up.
    Returns (final_hidden_states, router_logits).
"""
import torch
import torch.nn.functional as F
bsz, seqlen, hdim = hidden_states.shape
flat = hidden_states.view(-1, hdim)
router_logits = gate_linear(flat)
    # Use hub routing if available; otherwise fall back to softmax + top-k
routed = None
if available():
try:
routed = route_topk(router_logits, top_k)
except Exception:
routed = None
if routed is None:
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(flat.dtype)
x_rep = flat.repeat_interleave(top_k, dim=0)
y = torch.empty_like(x_rep)
flat_idx = topk_idx.view(-1)
        # nn.ModuleList (as passed in from the Mixtral patch) has no num_experts attribute
        num_experts = getattr(experts_module, "num_experts", len(experts_module))
        for i in range(num_experts):
            expert = experts_module[i]
            y[flat_idx == i] = expert(x_rep[flat_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
return y.reshape(bsz, seqlen, hdim), router_logits
    # Routing succeeded via the hub, but the grouped GEMM path (matmul_ogs + swiglu) is not
    # wired up yet, so fall back to per-expert compute for correctness. The hub outputs are
    # unpacked only to exercise the routing call; the computation below reproduces the naive
    # result exactly (no speedup yet), and detailed integration will follow.
    ex_routing_data, gather_idx, scatter_idx = routed  # currently unused beyond validating the call
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype)
x_rep = flat.repeat_interleave(top_k, dim=0)
y = torch.empty_like(x_rep)
flat_idx = topk_idx.view(-1)
    # nn.ModuleList (as passed in from the Mixtral patch) has no num_experts attribute
    num_experts = getattr(experts_module, "num_experts", len(experts_module))
    for i in range(num_experts):
        expert = experts_module[i]
        y[flat_idx == i] = expert(x_rep[flat_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
return y.reshape(bsz, seqlen, hdim), router_logits
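
A self-contained sketch (not part of the diff) of the stub's calling convention; the toy gate and experts below are made up for illustration and only mirror the interface the stub expects (an indexable experts container with a num_experts attribute, plus a gate Linear):

import torch
import torch.nn as nn

hdim, n_experts, top_k = 16, 4, 2
gate = nn.Linear(hdim, n_experts, bias=False)
experts = nn.ModuleList(nn.Linear(hdim, hdim) for _ in range(n_experts))  # stand-ins for MLP experts
experts.num_experts = n_experts  # attribute the stub looks up

hidden = torch.randn(2, 5, hdim)  # (batch, seq, hidden)
out, router_logits = moe_ffn_forward_stub(hidden, gate, experts, top_k)
assert out.shape == hidden.shape
assert router_logits.shape == (2 * 5, n_experts)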

View File

@@ -0,0 +1,16 @@
"""
Placeholder for PyTorch 2.8+ grouped GEMM MoE path.
Currently probes availability; full integration to be implemented.
"""
from __future__ import annotations


def available() -> bool:
try:
import torch # noqa: F401
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
return ver >= (2, 8)
except Exception:
return False
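
Not part of the diff: the probe mirrors backends._probe_torch_grouped, and a caller would gate on it roughly as sketched here. The module path is assumed to match this file, and the grouped kernel call itself is deliberately left out while this file is a placeholder:

from axolotl.kernels.moe import torch_grouped  # module name assumed

# The probe strips any local-version suffix before comparing (major, minor):
ver = tuple(int(x) for x in "2.8.0+cu128".split("+")[0].split(".")[:2])
assert ver == (2, 8)

if torch_grouped.available():
    pass  # torch>=2.8 detected: a future grouped GEMM MoE forward would be dispatched here
else:
    pass  # older torch: get_moe_backend_name() resolves to hf_triton or naive instead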

View File

@@ -6,8 +6,13 @@ import torch
def patch_mixtral_moe_forward_zero3() -> None:
import warnings
import torch.nn.functional as F
from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
hidden_states
@@ -21,21 +26,42 @@ def patch_mixtral_moe_forward_zero3() -> None:
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
backend = get_moe_backend_name()
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
            # Stub path: use kernels hub routing; experts are still computed one at a time
try:
final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub(
hidden_states.view(batch_size, sequence_length, hidden_dim),
self.gate,
self.experts,
self.top_k,
)
return final_hidden_states, router_logits
except Exception as e:
warnings.warn(f"hf_triton backend failed, falling back to naive: {e}")
elif (
backend == MOEBackend.TORCH_GROUPED
and not _moe_backends._probe_torch_grouped()
):
warnings.warn(
"torch_grouped selected but not available; falling back to naive"
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(
routing_weights, self.top_k, dim=-1, sorted=False
)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states_rep)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
sel = flat_topk_idx == i
if sel.any():
y[sel] = expert(hidden_states_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
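
Not part of the diff: a sketch of how the new dispatch might be exercised end to end, assuming the patch is applied before training starts and that the import path matches axolotl's existing Mixtral monkeypatch module:

import os

# Pick the backend before the patched forward runs; leaving it unset means "auto" probing.
os.environ["AXOLOTL_MOE_BACKEND"] = "hf_triton"

from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3  # import path assumed

patch_mixtral_moe_forward_zero3()
# Mixtral MoE blocks now route through the hf_triton stub when the kernels hub is importable,
# and fall back to the naive per-expert loop (with a warning) otherwise.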