just grouped_mm for now

This commit is contained in:
Dan Saunders
2025-09-15 23:03:18 -04:00
parent 773d7e4291
commit 7d572b58d1
6 changed files with 32 additions and 264 deletions

View File

@@ -2,18 +2,17 @@ MoE Backends in Axolotl
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
- Set `moe_backend: auto|torch_grouped|naive`
Behavior
- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
- hf_triton: uses the Hugging Face kernels hub (kernels-community/triton_kernels) when available.
- torch_grouped: targets PyTorch 2.8+ grouped GEMM.
- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- naive: keeps the reference per-expert loop.
Notes
- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
- No changes to training scripts are required; selection happens inside the model forward.
Example
moe_backend: hf_triton
moe_backend: torch_grouped
accelerate launch -m axolotl.cli.train path/to/config.yaml

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python
import argparse
import os
import time
import torch
@@ -52,20 +51,6 @@ def forward_naive(
return y.view(bsz, seqlen, hdim)
def forward_hf_triton(
hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
):
try:
from axolotl.kernels.moe import hf_triton as _hf
except Exception:
return None
try:
y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k)
return y
except Exception:
return None
def bench(fn, *args, iters=50, warmup=10, sync=True):
# warmup
for _ in range(warmup):
@@ -159,33 +144,6 @@ def main():
with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k)
# HF Triton
t_hf = forward_hf_triton
y = t_hf(x, gate, experts, args.top_k)
if y is not None:
t_ms = bench(
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
)
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
speedup = t_naive / t_ms
print(
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
)
# parity for hf_triton vs naive
with torch.no_grad():
y_fast = y
y_ref32 = y_ref.float()
y_fast32 = y_fast.float()
diff = (y_ref32 - y_fast32).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
print(
f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else:
print("hf_triton\tN/A (kernels hub not available)")
# torch_grouped backend (PyTorch 2.8+)
try:
from axolotl.kernels.moe import torch_grouped as tg

View File

@@ -4,7 +4,6 @@ from enum import Enum
class MOEBackend(str, Enum):
AUTO = "auto"
HF_TRITON = "hf_triton"
TORCH_GROUPED = "torch_grouped"
NAIVE = "naive"
@@ -20,17 +19,6 @@ def _probe_torch_grouped() -> bool:
return False
def _probe_hf_triton() -> bool:
try:
# The hub loads kernels lazily; this import is a light probe.
import importlib
importlib.import_module("kernels")
return True
except Exception:
return False
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
@@ -47,17 +35,10 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
if selected == MOEBackend.AUTO:
if _probe_torch_grouped():
return MOEBackend.TORCH_GROUPED
if _probe_hf_triton():
return MOEBackend.HF_TRITON
return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive"
"torch_grouped requested but torch>=2.8 not detected; falling back to naive"
)
return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE
if selected == MOEBackend.HF_TRITON and not _probe_hf_triton():
warnings.warn(
"hf_triton requested but kernels hub not available; falling back to torch_grouped/naive"
)
return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE
return MOEBackend.NAIVE
return selected

View File

@@ -1,177 +0,0 @@
"""
Adapter for Hugging Face kernels hub (kernels-community/triton_kernels).
This file provides light probes and placeholders for future integration.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional, Tuple
@dataclass
class HFTritonHandles:
routing: Any
matmul_ogs: Any
swiglu: Any
def available() -> bool:
try:
import kernels # noqa: F401
return True
except Exception:
return False
# Cache loaded handles so we don't trigger repeated hub fetches
_CACHED_HANDLES: Optional[HFTritonHandles] = None
_LOAD_ATTEMPTED: bool = False
def load() -> Optional[HFTritonHandles]:
global _CACHED_HANDLES, _LOAD_ATTEMPTED
if _CACHED_HANDLES is not None:
return _CACHED_HANDLES
if _LOAD_ATTEMPTED:
# Previously failed; avoid spamming retries per call
return None
_LOAD_ATTEMPTED = True
try:
from kernels import get_kernel
tk = get_kernel("kernels-community/triton_kernels")
_CACHED_HANDLES = HFTritonHandles(
routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu
)
return _CACHED_HANDLES
except Exception:
# Keep None in cache state to prevent repeated fetch attempts
return None
def route_topk(logits, top_k: int):
handles = load()
if handles is None:
return None
return handles.routing.routing_torch(logits, n_expts_act=top_k)
def swiglu(x, alpha, limit=1.0, routing_data=None):
handles = load()
if handles is None:
return None
pc = handles.swiglu.PrecisionConfig(limit=limit)
return handles.swiglu.swiglu(x, alpha, pc, routing_data)
def moe_ffn_forward_stub(
hidden_states, gate_linear, experts_module, top_k: int
) -> Tuple[object, object]:
"""
Temporary stub that uses kernels hub routing, but falls back to per-expert compute.
Returns (final_hidden_states, router_logits).
"""
import torch
import torch.nn.functional as F
bsz, seqlen, hdim = hidden_states.shape
flat = hidden_states.view(-1, hdim)
router_logits = gate_linear(flat)
# Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
handles = load()
if handles is not None:
try:
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
router_logits, n_expts_act=top_k
)
# Prepare and cache expert weights: shapes [E, K, N]
import torch
E = experts_module.num_experts
dev = flat.device
dt = flat.dtype
if (
not hasattr(experts_module, "_stacked_w1")
or experts_module._stacked_w1.device != dev
or experts_module._stacked_w1.dtype != dt
):
W1 = []
W3 = []
W2 = []
for i in range(E):
exp = experts_module[i]
W1.append(exp.w1.weight.t())
W3.append(exp.w3.weight.t())
W2.append(exp.w2.weight.t())
experts_module._stacked_w1 = (
torch.stack(W1, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
experts_module._stacked_w3 = (
torch.stack(W3, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
experts_module._stacked_w2 = (
torch.stack(W2, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
W1 = experts_module._stacked_w1
W3 = experts_module._stacked_w3
# Fused up+gate: single matmul on concatenated weights [E, H, 2I]
W13 = getattr(experts_module, "_stacked_w13", None)
if (
W13 is None
or W13.device != dev
or W13.dtype != dt
or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
):
W13 = torch.cat([W1, W3], dim=-1).contiguous()
experts_module._stacked_w13 = W13
Y13 = handles.matmul_ogs.matmul_ogs(
flat,
W13,
None,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
# Use kernels hub SwiGLU for optimal MoE launch
sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
# Down projection weights [E, inter, hidden]
W2 = experts_module._stacked_w2
# Down matmul with fused scatter back using scatter_indx
Out = handles.matmul_ogs.matmul_ogs(
Hidden,
W2,
None,
routing_data=routing_data,
gather_indx=None,
scatter_indx=scatter_idx,
precision_config=handles.matmul_ogs.PrecisionConfig(),
gammas=routing_data.gate_scal,
)
return Out.view(bsz, seqlen, hdim), router_logits
except Exception:
pass
# Fallback naive path for correctness
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(flat.dtype)
x_rep = flat.repeat_interleave(top_k, dim=0)
y = torch.empty_like(x_rep)
flat_idx = topk_idx.view(-1)
for i in range(experts_module.num_experts):
expert = experts_module[i]
sel = flat_idx == i
if sel.any():
y[sel] = expert(x_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
return y.reshape(bsz, seqlen, hdim), router_logits

View File

@@ -10,7 +10,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
import torch.nn.functional as F
from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton
from axolotl.kernels.moe import backends as _moe_backends
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
def mlp_forward(self, hidden_states):
@@ -28,19 +28,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
router_logits = self.gate(hidden_states)
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
# Stub path: use kernels hub routing and fallback per-expert compute
try:
final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub(
hidden_states.view(batch_size, sequence_length, hidden_dim),
self.gate,
self.experts,
self.top_k,
)
return final_hidden_states, router_logits
except Exception as e:
warnings.warn(f"hf_triton backend failed, falling back to naive: {e}")
elif (
if (
backend == MOEBackend.TORCH_GROUPED
and not _moe_backends._probe_torch_grouped()
):
@@ -73,4 +61,23 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
)
MixtralBlockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
# Wrap forward to support optional torch_grouped backend via config
from axolotl.kernels.moe import torch_grouped as _tg
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
def moe_forward_grouped(self, hidden_states: torch.Tensor) -> torch.Tensor:
bsz, seqlen, hdim = hidden_states.shape
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
if y is None:
return moe_forward(self, hidden_states)
return y, router_logits
MixtralSparseMoeBlock.forward = moe_forward_grouped
else:
MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -133,10 +133,10 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
default=None,
json_schema_extra={
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
"description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
},
)