just grouped_mm for now
@@ -2,18 +2,17 @@ MoE Backends in Axolotl
 
 Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
 
-- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
+- Set `moe_backend: auto|torch_grouped|naive`
 
 Behavior
-- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
-- hf_triton: uses the Hugging Face kernels hub (kernels-community/triton_kernels) when available.
-- torch_grouped: targets PyTorch 2.8+ grouped GEMM.
+- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
+- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
 - naive: keeps the reference per-expert loop.
 
 Notes
-- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
-- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
+- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
+- No changes to training scripts are required; selection happens inside the model forward.
 
 Example
-moe_backend: hf_triton
+moe_backend: torch_grouped
 accelerate launch -m axolotl.cli.train path/to/config.yaml
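For orientation, a fuller training-config sketch around the documented key (only `moe_backend` relates to this change; the other keys are ordinary Axolotl options shown purely for illustration):

    # config.yaml (illustrative)
    base_model: mistralai/Mixtral-8x7B-v0.1
    sequence_len: 2048
    micro_batch_size: 1
    moe_backend: torch_grouped   # or: auto | naive
    output_dir: ./outputs/mixtral-moe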
@@ -1,6 +1,5 @@
#!/usr/bin/env python
import argparse
import os
import time

import torch
@@ -52,20 +51,6 @@ def forward_naive(
     return y.view(bsz, seqlen, hdim)
 
 
-def forward_hf_triton(
-    hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
-):
-    try:
-        from axolotl.kernels.moe import hf_triton as _hf
-    except Exception:
-        return None
-    try:
-        y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k)
-        return y
-    except Exception:
-        return None
-
-
 def bench(fn, *args, iters=50, warmup=10, sync=True):
     # warmup
     for _ in range(warmup):
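The hunk above cuts off inside `bench`; a minimal timing helper consistent with how `bench` is called in `main()` might look as follows (an illustrative sketch, not the repository's exact implementation):

    import time
    import torch

    def bench_sketch(fn, *args, iters=50, warmup=10, sync=True):
        # Warm up first so autotuning / lazy CUDA init does not skew the timing.
        for _ in range(warmup):
            fn(*args)
        if sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        if sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        # Mean milliseconds per call, matching the t_ms values printed by main().
        return (time.perf_counter() - start) * 1000.0 / iters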
@@ -159,33 +144,6 @@ def main():
     with torch.no_grad():
         y_ref = forward_naive(x, gate, experts, args.top_k)
 
-    # HF Triton
-    t_hf = forward_hf_triton
-    y = t_hf(x, gate, experts, args.top_k)
-    if y is not None:
-        t_ms = bench(
-            t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
-        )
-        tflops = flops_total / ((t_ms / 1000.0) * 1e12)
-        speedup = t_naive / t_ms
-        print(
-            f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
-        )
-        # parity for hf_triton vs naive
-        with torch.no_grad():
-            y_fast = y
-            y_ref32 = y_ref.float()
-            y_fast32 = y_fast.float()
-            diff = (y_ref32 - y_fast32).abs()
-            max_abs = diff.max().item()
-            mean_abs = diff.mean().item()
-            rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
-            print(
-                f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-            )
-    else:
-        print("hf_triton\tN/A (kernels hub not available)")
-
     # torch_grouped backend (PyTorch 2.8+)
     try:
         from axolotl.kernels.moe import torch_grouped as tg
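The parity check removed above reduces to three scalars; as a standalone sketch using the same formulas (float32 accumulation, small epsilon in the relative-L2 denominator):

    import torch

    def parity_metrics(y_ref: torch.Tensor, y_fast: torch.Tensor) -> dict:
        # Cast to float32 so the comparison is not dominated by bf16/fp16 rounding.
        ref32, fast32 = y_ref.float(), y_fast.float()
        diff = (ref32 - fast32).abs()
        return {
            "max_abs": diff.max().item(),
            "mean_abs": diff.mean().item(),
            # Relative L2 error: ||ref - fast||_2 / ||ref||_2
            "rel_l2": (diff.pow(2).sum() / (ref32.pow(2).sum() + 1e-12)).sqrt().item(),
        }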
@@ -4,7 +4,6 @@ from enum import Enum
 
 class MOEBackend(str, Enum):
     AUTO = "auto"
-    HF_TRITON = "hf_triton"
     TORCH_GROUPED = "torch_grouped"
     NAIVE = "naive"
 
@@ -20,17 +19,6 @@ def _probe_torch_grouped() -> bool:
         return False
 
 
-def _probe_hf_triton() -> bool:
-    try:
-        # The hub loads kernels lazily; this import is a light probe.
-        import importlib
-
-        importlib.import_module("kernels")
-        return True
-    except Exception:
-        return False
-
-
 def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     """
     Resolve the desired MoE backend using, in order of precedence:
@@ -47,17 +35,10 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     if selected == MOEBackend.AUTO:
         if _probe_torch_grouped():
             return MOEBackend.TORCH_GROUPED
-        if _probe_hf_triton():
-            return MOEBackend.HF_TRITON
         return MOEBackend.NAIVE
     if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
         warnings.warn(
-            "torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive"
+            "torch_grouped requested but torch>=2.8 not detected; falling back to naive"
         )
-        return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE
-    if selected == MOEBackend.HF_TRITON and not _probe_hf_triton():
-        warnings.warn(
-            "hf_triton requested but kernels hub not available; falling back to torch_grouped/naive"
-        )
-        return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE
+        return MOEBackend.NAIVE
     return selected
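With the hf_triton branches removed, resolution is a two-way choice; a small usage sketch of the selector above (module path and names as they appear in this diff):

    from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

    # Passing None (or leaving moe_backend unset in the config) resolves via "auto".
    backend = get_moe_backend_name("torch_grouped")
    if backend is MOEBackend.TORCH_GROUPED:
        print("grouped GEMM path selected")
    else:
        # torch>=2.8 not detected: the selector warned and fell back.
        print(f"using the {backend.value} path")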
@@ -1,177 +0,0 @@
-"""
-Adapter for Hugging Face kernels hub (kernels-community/triton_kernels).
-This file provides light probes and placeholders for future integration.
-"""
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any, Optional, Tuple
-
-
-@dataclass
-class HFTritonHandles:
-    routing: Any
-    matmul_ogs: Any
-    swiglu: Any
-
-
-def available() -> bool:
-    try:
-        import kernels  # noqa: F401
-
-        return True
-    except Exception:
-        return False
-
-
-# Cache loaded handles so we don't trigger repeated hub fetches
-_CACHED_HANDLES: Optional[HFTritonHandles] = None
-_LOAD_ATTEMPTED: bool = False
-
-
-def load() -> Optional[HFTritonHandles]:
-    global _CACHED_HANDLES, _LOAD_ATTEMPTED
-    if _CACHED_HANDLES is not None:
-        return _CACHED_HANDLES
-    if _LOAD_ATTEMPTED:
-        # Previously failed; avoid spamming retries per call
-        return None
-    _LOAD_ATTEMPTED = True
-    try:
-        from kernels import get_kernel
-
-        tk = get_kernel("kernels-community/triton_kernels")
-        _CACHED_HANDLES = HFTritonHandles(
-            routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu
-        )
-        return _CACHED_HANDLES
-    except Exception:
-        # Keep None in cache state to prevent repeated fetch attempts
-        return None
-
-
-def route_topk(logits, top_k: int):
-    handles = load()
-    if handles is None:
-        return None
-    return handles.routing.routing_torch(logits, n_expts_act=top_k)
-
-
-def swiglu(x, alpha, limit=1.0, routing_data=None):
-    handles = load()
-    if handles is None:
-        return None
-    pc = handles.swiglu.PrecisionConfig(limit=limit)
-    return handles.swiglu.swiglu(x, alpha, pc, routing_data)
-
-
-def moe_ffn_forward_stub(
-    hidden_states, gate_linear, experts_module, top_k: int
-) -> Tuple[object, object]:
-    """
-    Temporary stub that uses kernels hub routing, but falls back to per-expert compute.
-    Returns (final_hidden_states, router_logits).
-    """
-    import torch
-    import torch.nn.functional as F
-
-    bsz, seqlen, hdim = hidden_states.shape
-    flat = hidden_states.view(-1, hdim)
-    router_logits = gate_linear(flat)
-    # Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
-    handles = load()
-    if handles is not None:
-        try:
-            routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
-                router_logits, n_expts_act=top_k
-            )
-            # Prepare and cache expert weights: shapes [E, K, N]
-            import torch
-
-            E = experts_module.num_experts
-            dev = flat.device
-            dt = flat.dtype
-            if (
-                not hasattr(experts_module, "_stacked_w1")
-                or experts_module._stacked_w1.device != dev
-                or experts_module._stacked_w1.dtype != dt
-            ):
-                W1 = []
-                W3 = []
-                W2 = []
-                for i in range(E):
-                    exp = experts_module[i]
-                    W1.append(exp.w1.weight.t())
-                    W3.append(exp.w3.weight.t())
-                    W2.append(exp.w2.weight.t())
-                experts_module._stacked_w1 = (
-                    torch.stack(W1, dim=0)
-                    .to(device=dev, dtype=dt, non_blocking=True)
-                    .contiguous()
-                )
-                experts_module._stacked_w3 = (
-                    torch.stack(W3, dim=0)
-                    .to(device=dev, dtype=dt, non_blocking=True)
-                    .contiguous()
-                )
-                experts_module._stacked_w2 = (
-                    torch.stack(W2, dim=0)
-                    .to(device=dev, dtype=dt, non_blocking=True)
-                    .contiguous()
-                )
-            W1 = experts_module._stacked_w1
-            W3 = experts_module._stacked_w3
-            # Fused up+gate: single matmul on concatenated weights [E, H, 2I]
-            W13 = getattr(experts_module, "_stacked_w13", None)
-            if (
-                W13 is None
-                or W13.device != dev
-                or W13.dtype != dt
-                or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
-            ):
-                W13 = torch.cat([W1, W3], dim=-1).contiguous()
-                experts_module._stacked_w13 = W13
-            Y13 = handles.matmul_ogs.matmul_ogs(
-                flat,
-                W13,
-                None,
-                routing_data=routing_data,
-                gather_indx=gather_idx,
-                scatter_indx=None,
-                precision_config=handles.matmul_ogs.PrecisionConfig(),
-            )
-            # Use kernels hub SwiGLU for optimal MoE launch
-            sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
-            Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
-            # Down projection weights [E, inter, hidden]
-            W2 = experts_module._stacked_w2
-            # Down matmul with fused scatter back using scatter_indx
-            Out = handles.matmul_ogs.matmul_ogs(
-                Hidden,
-                W2,
-                None,
-                routing_data=routing_data,
-                gather_indx=None,
-                scatter_indx=scatter_idx,
-                precision_config=handles.matmul_ogs.PrecisionConfig(),
-                gammas=routing_data.gate_scal,
-            )
-            return Out.view(bsz, seqlen, hdim), router_logits
-        except Exception:
-            pass
-    # Fallback naive path for correctness
-    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
-    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
-    topk_weight = topk_weight.to(flat.dtype)
-    x_rep = flat.repeat_interleave(top_k, dim=0)
-    y = torch.empty_like(x_rep)
-    flat_idx = topk_idx.view(-1)
-    for i in range(experts_module.num_experts):
-        expert = experts_module[i]
-        sel = flat_idx == i
-        if sel.any():
-            y[sel] = expert(x_rep[sel])
-    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-    return y.reshape(bsz, seqlen, hdim), router_logits
@@ -10,7 +10,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
 
     import torch.nn.functional as F
 
-    from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton
+    from axolotl.kernels.moe import backends as _moe_backends
     from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
 
     def mlp_forward(self, hidden_states):
@@ -28,19 +28,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
         router_logits = self.gate(hidden_states)
         preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
         backend = get_moe_backend_name(preferred)
-        if backend == MOEBackend.HF_TRITON and _hf_triton.available():
-            # Stub path: use kernels hub routing and fallback per-expert compute
-            try:
-                final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub(
-                    hidden_states.view(batch_size, sequence_length, hidden_dim),
-                    self.gate,
-                    self.experts,
-                    self.top_k,
-                )
-                return final_hidden_states, router_logits
-            except Exception as e:
-                warnings.warn(f"hf_triton backend failed, falling back to naive: {e}")
-        elif (
+        if (
             backend == MOEBackend.TORCH_GROUPED
             and not _moe_backends._probe_torch_grouped()
         ):
@@ -73,4 +61,23 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
         )
 
     MixtralBlockSparseTop2MLP.forward = mlp_forward
-    MixtralSparseMoeBlock.forward = moe_forward
+    # Wrap forward to support optional torch_grouped backend via config
+    from axolotl.kernels.moe import torch_grouped as _tg
+
+    preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
+    backend = get_moe_backend_name(preferred)
+
+    if backend == MOEBackend.TORCH_GROUPED and _tg.available():
+
+        def moe_forward_grouped(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            bsz, seqlen, hdim = hidden_states.shape
+            y, router_logits = _tg.moe_ffn_forward_grouped(
+                hidden_states, self.gate, self.experts, self.top_k
+            )
+            if y is None:
+                return moe_forward(self, hidden_states)
+            return y, router_logits
+
+        MixtralSparseMoeBlock.forward = moe_forward_grouped
+    else:
+        MixtralSparseMoeBlock.forward = moe_forward
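`moe_ffn_forward_grouped` itself is not part of this diff; conceptually, a grouped backend sorts the (token, expert) assignments so each expert owns one contiguous slice, runs the expert GEMMs in as few launches as possible, and scatters the weighted results back. A plain-PyTorch sketch of that layout follows (illustrative only: the per-segment loop stands in for the real grouped GEMM call, and the function name, SwiGLU form, and weight shapes are assumptions rather than the repository's code):

    import torch
    import torch.nn.functional as F

    def grouped_moe_sketch(x, w_router, w1, w2, w3, top_k):
        # x: [tokens, hidden]; w_router: [experts, hidden]
        # w1, w3: [experts, hidden, inter]; w2: [experts, inter, hidden]
        logits = x @ w_router.t()
        weights = F.softmax(logits, dim=-1, dtype=torch.float)
        topw, topi = torch.topk(weights, top_k, dim=-1)
        topw = (topw / topw.sum(dim=-1, keepdim=True)).to(x.dtype)

        # Sort (token, expert) pairs by expert id so every expert gets a contiguous segment.
        flat_expert = topi.reshape(-1)
        order = torch.argsort(flat_expert, stable=True)
        x_rep = x.repeat_interleave(top_k, dim=0)[order]
        counts = torch.bincount(flat_expert, minlength=w1.shape[0]).tolist()

        # One GEMM per contiguous segment; a real backend replaces this loop with a
        # single grouped GEMM over all segments.
        out_sorted = torch.empty_like(x_rep)
        start = 0
        for e, n in enumerate(counts):
            if n == 0:
                continue
            seg = x_rep[start:start + n]
            h = F.silu(seg @ w1[e]) * (seg @ w3[e])  # SwiGLU: gate * up
            out_sorted[start:start + n] = h @ w2[e]
            start += n

        # Undo the sort and combine the top_k expert outputs with the router weights.
        out = torch.empty_like(out_sorted)
        out[order] = out_sorted
        return (out.view(x.shape[0], top_k, -1) * topw.unsqueeze(-1)).sum(dim=1)

As in the deleted stub earlier in this diff, a production path would also cache the stacked expert weights instead of restacking them on every call.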
@@ -133,10 +133,10 @@ class AxolotlInputConfig(
     vllm: VllmConfig | None = Field(
         default_factory=lambda: VllmConfig(),
     )
-    moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
+    moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
+            "description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
         },
     )
 