just grouped_mm for now
This commit is contained in:
@@ -2,18 +2,17 @@ MoE Backends in Axolotl
|
|||||||
|
|
||||||
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
|
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
|
||||||
|
|
||||||
- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
|
- Set `moe_backend: auto|torch_grouped|naive`
|
||||||
|
|
||||||
Behavior
|
Behavior
|
||||||
- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
|
- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
|
||||||
- hf_triton: uses the Hugging Face kernels hub (kernels-community/triton_kernels) when available.
|
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
|
||||||
- torch_grouped: targets PyTorch 2.8+ grouped GEMM.
|
|
||||||
- naive: keeps the reference per-expert loop.
|
- naive: keeps the reference per-expert loop.
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
|
- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
|
||||||
- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
|
- No changes to training scripts are required; selection happens inside the model forward.
|
||||||
|
|
||||||
Example
|
Example
|
||||||
moe_backend: hf_triton
|
moe_backend: torch_grouped
|
||||||
accelerate launch -m axolotl.cli.train path/to/config.yaml
|
accelerate launch -m axolotl.cli.train path/to/config.yaml
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -52,20 +51,6 @@ def forward_naive(
|
|||||||
return y.view(bsz, seqlen, hdim)
|
return y.view(bsz, seqlen, hdim)
|
||||||
|
|
||||||
|
|
||||||
def forward_hf_triton(
|
|
||||||
hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
from axolotl.kernels.moe import hf_triton as _hf
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k)
|
|
||||||
return y
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def bench(fn, *args, iters=50, warmup=10, sync=True):
|
def bench(fn, *args, iters=50, warmup=10, sync=True):
|
||||||
# warmup
|
# warmup
|
||||||
for _ in range(warmup):
|
for _ in range(warmup):
|
||||||
@@ -159,33 +144,6 @@ def main():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
y_ref = forward_naive(x, gate, experts, args.top_k)
|
y_ref = forward_naive(x, gate, experts, args.top_k)
|
||||||
|
|
||||||
# HF Triton
|
|
||||||
t_hf = forward_hf_triton
|
|
||||||
y = t_hf(x, gate, experts, args.top_k)
|
|
||||||
if y is not None:
|
|
||||||
t_ms = bench(
|
|
||||||
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
|
|
||||||
)
|
|
||||||
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
|
|
||||||
speedup = t_naive / t_ms
|
|
||||||
print(
|
|
||||||
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
|
|
||||||
)
|
|
||||||
# parity for hf_triton vs naive
|
|
||||||
with torch.no_grad():
|
|
||||||
y_fast = y
|
|
||||||
y_ref32 = y_ref.float()
|
|
||||||
y_fast32 = y_fast.float()
|
|
||||||
diff = (y_ref32 - y_fast32).abs()
|
|
||||||
max_abs = diff.max().item()
|
|
||||||
mean_abs = diff.mean().item()
|
|
||||||
rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
|
|
||||||
print(
|
|
||||||
f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print("hf_triton\tN/A (kernels hub not available)")
|
|
||||||
|
|
||||||
# torch_grouped backend (PyTorch 2.8+)
|
# torch_grouped backend (PyTorch 2.8+)
|
||||||
try:
|
try:
|
||||||
from axolotl.kernels.moe import torch_grouped as tg
|
from axolotl.kernels.moe import torch_grouped as tg
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from enum import Enum
|
|||||||
|
|
||||||
class MOEBackend(str, Enum):
|
class MOEBackend(str, Enum):
|
||||||
AUTO = "auto"
|
AUTO = "auto"
|
||||||
HF_TRITON = "hf_triton"
|
|
||||||
TORCH_GROUPED = "torch_grouped"
|
TORCH_GROUPED = "torch_grouped"
|
||||||
NAIVE = "naive"
|
NAIVE = "naive"
|
||||||
|
|
||||||
@@ -20,17 +19,6 @@ def _probe_torch_grouped() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _probe_hf_triton() -> bool:
|
|
||||||
try:
|
|
||||||
# The hub loads kernels lazily; this import is a light probe.
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
importlib.import_module("kernels")
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
||||||
"""
|
"""
|
||||||
Resolve the desired MoE backend using, in order of precedence:
|
Resolve the desired MoE backend using, in order of precedence:
|
||||||
@@ -47,17 +35,10 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
|||||||
if selected == MOEBackend.AUTO:
|
if selected == MOEBackend.AUTO:
|
||||||
if _probe_torch_grouped():
|
if _probe_torch_grouped():
|
||||||
return MOEBackend.TORCH_GROUPED
|
return MOEBackend.TORCH_GROUPED
|
||||||
if _probe_hf_triton():
|
|
||||||
return MOEBackend.HF_TRITON
|
|
||||||
return MOEBackend.NAIVE
|
return MOEBackend.NAIVE
|
||||||
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
|
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive"
|
"torch_grouped requested but torch>=2.8 not detected; falling back to naive"
|
||||||
)
|
)
|
||||||
return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE
|
return MOEBackend.NAIVE
|
||||||
if selected == MOEBackend.HF_TRITON and not _probe_hf_triton():
|
|
||||||
warnings.warn(
|
|
||||||
"hf_triton requested but kernels hub not available; falling back to torch_grouped/naive"
|
|
||||||
)
|
|
||||||
return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE
|
|
||||||
return selected
|
return selected
|
||||||
|
|||||||
@@ -1,177 +0,0 @@
|
|||||||
"""
|
|
||||||
Adapter for Hugging Face kernels hub (kernels-community/triton_kernels).
|
|
||||||
This file provides light probes and placeholders for future integration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Optional, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HFTritonHandles:
|
|
||||||
routing: Any
|
|
||||||
matmul_ogs: Any
|
|
||||||
swiglu: Any
|
|
||||||
|
|
||||||
|
|
||||||
def available() -> bool:
|
|
||||||
try:
|
|
||||||
import kernels # noqa: F401
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Cache loaded handles so we don't trigger repeated hub fetches
|
|
||||||
_CACHED_HANDLES: Optional[HFTritonHandles] = None
|
|
||||||
_LOAD_ATTEMPTED: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def load() -> Optional[HFTritonHandles]:
|
|
||||||
global _CACHED_HANDLES, _LOAD_ATTEMPTED
|
|
||||||
if _CACHED_HANDLES is not None:
|
|
||||||
return _CACHED_HANDLES
|
|
||||||
if _LOAD_ATTEMPTED:
|
|
||||||
# Previously failed; avoid spamming retries per call
|
|
||||||
return None
|
|
||||||
_LOAD_ATTEMPTED = True
|
|
||||||
try:
|
|
||||||
from kernels import get_kernel
|
|
||||||
|
|
||||||
tk = get_kernel("kernels-community/triton_kernels")
|
|
||||||
_CACHED_HANDLES = HFTritonHandles(
|
|
||||||
routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu
|
|
||||||
)
|
|
||||||
return _CACHED_HANDLES
|
|
||||||
except Exception:
|
|
||||||
# Keep None in cache state to prevent repeated fetch attempts
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def route_topk(logits, top_k: int):
|
|
||||||
handles = load()
|
|
||||||
if handles is None:
|
|
||||||
return None
|
|
||||||
return handles.routing.routing_torch(logits, n_expts_act=top_k)
|
|
||||||
|
|
||||||
|
|
||||||
def swiglu(x, alpha, limit=1.0, routing_data=None):
|
|
||||||
handles = load()
|
|
||||||
if handles is None:
|
|
||||||
return None
|
|
||||||
pc = handles.swiglu.PrecisionConfig(limit=limit)
|
|
||||||
return handles.swiglu.swiglu(x, alpha, pc, routing_data)
|
|
||||||
|
|
||||||
|
|
||||||
def moe_ffn_forward_stub(
|
|
||||||
hidden_states, gate_linear, experts_module, top_k: int
|
|
||||||
) -> Tuple[object, object]:
|
|
||||||
"""
|
|
||||||
Temporary stub that uses kernels hub routing, but falls back to per-expert compute.
|
|
||||||
Returns (final_hidden_states, router_logits).
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
bsz, seqlen, hdim = hidden_states.shape
|
|
||||||
flat = hidden_states.view(-1, hdim)
|
|
||||||
router_logits = gate_linear(flat)
|
|
||||||
# Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
|
|
||||||
handles = load()
|
|
||||||
if handles is not None:
|
|
||||||
try:
|
|
||||||
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
|
|
||||||
router_logits, n_expts_act=top_k
|
|
||||||
)
|
|
||||||
# Prepare and cache expert weights: shapes [E, K, N]
|
|
||||||
import torch
|
|
||||||
|
|
||||||
E = experts_module.num_experts
|
|
||||||
dev = flat.device
|
|
||||||
dt = flat.dtype
|
|
||||||
if (
|
|
||||||
not hasattr(experts_module, "_stacked_w1")
|
|
||||||
or experts_module._stacked_w1.device != dev
|
|
||||||
or experts_module._stacked_w1.dtype != dt
|
|
||||||
):
|
|
||||||
W1 = []
|
|
||||||
W3 = []
|
|
||||||
W2 = []
|
|
||||||
for i in range(E):
|
|
||||||
exp = experts_module[i]
|
|
||||||
W1.append(exp.w1.weight.t())
|
|
||||||
W3.append(exp.w3.weight.t())
|
|
||||||
W2.append(exp.w2.weight.t())
|
|
||||||
experts_module._stacked_w1 = (
|
|
||||||
torch.stack(W1, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w3 = (
|
|
||||||
torch.stack(W3, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w2 = (
|
|
||||||
torch.stack(W2, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
W1 = experts_module._stacked_w1
|
|
||||||
W3 = experts_module._stacked_w3
|
|
||||||
# Fused up+gate: single matmul on concatenated weights [E, H, 2I]
|
|
||||||
W13 = getattr(experts_module, "_stacked_w13", None)
|
|
||||||
if (
|
|
||||||
W13 is None
|
|
||||||
or W13.device != dev
|
|
||||||
or W13.dtype != dt
|
|
||||||
or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
|
|
||||||
):
|
|
||||||
W13 = torch.cat([W1, W3], dim=-1).contiguous()
|
|
||||||
experts_module._stacked_w13 = W13
|
|
||||||
Y13 = handles.matmul_ogs.matmul_ogs(
|
|
||||||
flat,
|
|
||||||
W13,
|
|
||||||
None,
|
|
||||||
routing_data=routing_data,
|
|
||||||
gather_indx=gather_idx,
|
|
||||||
scatter_indx=None,
|
|
||||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
|
||||||
)
|
|
||||||
# Use kernels hub SwiGLU for optimal MoE launch
|
|
||||||
sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
|
|
||||||
Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
|
|
||||||
# Down projection weights [E, inter, hidden]
|
|
||||||
W2 = experts_module._stacked_w2
|
|
||||||
# Down matmul with fused scatter back using scatter_indx
|
|
||||||
Out = handles.matmul_ogs.matmul_ogs(
|
|
||||||
Hidden,
|
|
||||||
W2,
|
|
||||||
None,
|
|
||||||
routing_data=routing_data,
|
|
||||||
gather_indx=None,
|
|
||||||
scatter_indx=scatter_idx,
|
|
||||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
|
||||||
gammas=routing_data.gate_scal,
|
|
||||||
)
|
|
||||||
return Out.view(bsz, seqlen, hdim), router_logits
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# Fallback naive path for correctness
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
|
||||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
|
||||||
topk_weight = topk_weight.to(flat.dtype)
|
|
||||||
x_rep = flat.repeat_interleave(top_k, dim=0)
|
|
||||||
y = torch.empty_like(x_rep)
|
|
||||||
flat_idx = topk_idx.view(-1)
|
|
||||||
for i in range(experts_module.num_experts):
|
|
||||||
expert = experts_module[i]
|
|
||||||
sel = flat_idx == i
|
|
||||||
if sel.any():
|
|
||||||
y[sel] = expert(x_rep[sel])
|
|
||||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
||||||
return y.reshape(bsz, seqlen, hdim), router_logits
|
|
||||||
@@ -10,7 +10,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
|||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton
|
from axolotl.kernels.moe import backends as _moe_backends
|
||||||
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
||||||
|
|
||||||
def mlp_forward(self, hidden_states):
|
def mlp_forward(self, hidden_states):
|
||||||
@@ -28,19 +28,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
|||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||||
backend = get_moe_backend_name(preferred)
|
backend = get_moe_backend_name(preferred)
|
||||||
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
|
if (
|
||||||
# Stub path: use kernels hub routing and fallback per-expert compute
|
|
||||||
try:
|
|
||||||
final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub(
|
|
||||||
hidden_states.view(batch_size, sequence_length, hidden_dim),
|
|
||||||
self.gate,
|
|
||||||
self.experts,
|
|
||||||
self.top_k,
|
|
||||||
)
|
|
||||||
return final_hidden_states, router_logits
|
|
||||||
except Exception as e:
|
|
||||||
warnings.warn(f"hf_triton backend failed, falling back to naive: {e}")
|
|
||||||
elif (
|
|
||||||
backend == MOEBackend.TORCH_GROUPED
|
backend == MOEBackend.TORCH_GROUPED
|
||||||
and not _moe_backends._probe_torch_grouped()
|
and not _moe_backends._probe_torch_grouped()
|
||||||
):
|
):
|
||||||
@@ -73,4 +61,23 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
MixtralBlockSparseTop2MLP.forward = mlp_forward
|
MixtralBlockSparseTop2MLP.forward = mlp_forward
|
||||||
MixtralSparseMoeBlock.forward = moe_forward
|
# Wrap forward to support optional torch_grouped backend via config
|
||||||
|
from axolotl.kernels.moe import torch_grouped as _tg
|
||||||
|
|
||||||
|
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||||
|
backend = get_moe_backend_name(preferred)
|
||||||
|
|
||||||
|
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
|
||||||
|
|
||||||
|
def moe_forward_grouped(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
|
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||||
|
hidden_states, self.gate, self.experts, self.top_k
|
||||||
|
)
|
||||||
|
if y is None:
|
||||||
|
return moe_forward(self, hidden_states)
|
||||||
|
return y, router_logits
|
||||||
|
|
||||||
|
MixtralSparseMoeBlock.forward = moe_forward_grouped
|
||||||
|
else:
|
||||||
|
MixtralSparseMoeBlock.forward = moe_forward
|
||||||
|
|||||||
@@ -133,10 +133,10 @@ class AxolotlInputConfig(
|
|||||||
vllm: VllmConfig | None = Field(
|
vllm: VllmConfig | None = Field(
|
||||||
default_factory=lambda: VllmConfig(),
|
default_factory=lambda: VllmConfig(),
|
||||||
)
|
)
|
||||||
moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
|
moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
|
"description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user