diff --git a/docs/moe_backends.md b/docs/moe_backends.md
index 50731d226..6fff4f4a1 100644
--- a/docs/moe_backends.md
+++ b/docs/moe_backends.md
@@ -2,18 +2,17 @@ MoE Backends in Axolotl
 
 Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
 
-- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
+- Set `moe_backend: auto|torch_grouped|naive`
 
 Behavior
-- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
-- hf_triton: uses the Hugging Face kernels hub (kernels-community/triton_kernels) when available.
-- torch_grouped: targets PyTorch 2.8+ grouped GEMM.
+- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
+- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
 - naive: keeps the reference per-expert loop.
 
 Notes
-- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
-- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
+- Current implementation wires the backend selector and routes Mixtral MoE through it. torch_grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
+- No changes to training scripts are required; selection happens inside the model forward.
 
 Example
-moe_backend: hf_triton
+moe_backend: torch_grouped
 accelerate launch -m axolotl.cli.train path/to/config.yaml
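Review note: a quick way to sanity-check the doc's claim about `auto` is to call the resolver this PR keeps in `backends.py` directly. Minimal sketch; it assumes `None` falls through to `auto`, per the config description.

```python
# Sketch: see which backend resolves in the current environment.
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

backend = get_moe_backend_name(None)  # None is treated as "auto" (assumption)
print(backend)  # TORCH_GROUPED on torch>=2.8, NAIVE otherwise

# Explicit requests resolve through the same function:
assert get_moe_backend_name("naive") is MOEBackend.NAIVE
```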
diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index 14653aba1..287a9cbe3 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 import argparse
-import os
 import time
 
 import torch
@@ -52,20 +51,6 @@ def forward_naive(
     return y.view(bsz, seqlen, hdim)
 
 
-def forward_hf_triton(
-    hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
-):
-    try:
-        from axolotl.kernels.moe import hf_triton as _hf
-    except Exception:
-        return None
-    try:
-        y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k)
-        return y
-    except Exception:
-        return None
-
-
 def bench(fn, *args, iters=50, warmup=10, sync=True):
     # warmup
     for _ in range(warmup):
@@ -159,33 +144,6 @@ def main():
     with torch.no_grad():
         y_ref = forward_naive(x, gate, experts, args.top_k)
 
-    # HF Triton
-    t_hf = forward_hf_triton
-    y = t_hf(x, gate, experts, args.top_k)
-    if y is not None:
-        t_ms = bench(
-            t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
-        )
-        tflops = flops_total / ((t_ms / 1000.0) * 1e12)
-        speedup = t_naive / t_ms
-        print(
-            f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
-        )
-        # parity for hf_triton vs naive
-        with torch.no_grad():
-            y_fast = y
-            y_ref32 = y_ref.float()
-            y_fast32 = y_fast.float()
-            diff = (y_ref32 - y_fast32).abs()
-            max_abs = diff.max().item()
-            mean_abs = diff.mean().item()
-            rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
-        print(
-            f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-        )
-    else:
-        print("hf_triton\tN/A (kernels hub not available)")
-
     # torch_grouped backend (PyTorch 2.8+)
     try:
         from axolotl.kernels.moe import torch_grouped as tg
diff --git a/src/axolotl/kernels/moe/backends.py b/src/axolotl/kernels/moe/backends.py
index 210db6040..49f365983 100644
--- a/src/axolotl/kernels/moe/backends.py
+++ b/src/axolotl/kernels/moe/backends.py
@@ -4,7 +4,6 @@ from enum import Enum
 
 
 class MOEBackend(str, Enum):
     AUTO = "auto"
-    HF_TRITON = "hf_triton"
     TORCH_GROUPED = "torch_grouped"
     NAIVE = "naive"
@@ -20,17 +19,6 @@ def _probe_torch_grouped() -> bool:
         return False
 
 
-def _probe_hf_triton() -> bool:
-    try:
-        # The hub loads kernels lazily; this import is a light probe.
-        import importlib
-
-        importlib.import_module("kernels")
-        return True
-    except Exception:
-        return False
-
-
 def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     """
     Resolve the desired MoE backend using, in order of precedence:
@@ -47,17 +35,10 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     if selected == MOEBackend.AUTO:
         if _probe_torch_grouped():
             return MOEBackend.TORCH_GROUPED
-        if _probe_hf_triton():
-            return MOEBackend.HF_TRITON
         return MOEBackend.NAIVE
     if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
         warnings.warn(
-            "torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive"
+            "torch_grouped requested but torch>=2.8 not detected; falling back to naive"
         )
-        return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE
-    if selected == MOEBackend.HF_TRITON and not _probe_hf_triton():
-        warnings.warn(
-            "hf_triton requested but kernels hub not available; falling back to torch_grouped/naive"
-        )
-        return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE
+        return MOEBackend.NAIVE
     return selected
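Review note: the remaining fallback branch above is straightforward to exercise. A minimal sketch, assuming an environment without PyTorch 2.8+ grouped GEMM support:

```python
# Sketch: an explicit torch_grouped request on older torch warns,
# then resolves to the naive backend.
import warnings

from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    backend = get_moe_backend_name("torch_grouped")

if backend is MOEBackend.NAIVE and caught:
    print(caught[0].message)  # "torch_grouped requested but torch>=2.8 not detected; ..."
```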
-""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - - -@dataclass -class HFTritonHandles: - routing: Any - matmul_ogs: Any - swiglu: Any - - -def available() -> bool: - try: - import kernels # noqa: F401 - - return True - except Exception: - return False - - -# Cache loaded handles so we don't trigger repeated hub fetches -_CACHED_HANDLES: Optional[HFTritonHandles] = None -_LOAD_ATTEMPTED: bool = False - - -def load() -> Optional[HFTritonHandles]: - global _CACHED_HANDLES, _LOAD_ATTEMPTED - if _CACHED_HANDLES is not None: - return _CACHED_HANDLES - if _LOAD_ATTEMPTED: - # Previously failed; avoid spamming retries per call - return None - _LOAD_ATTEMPTED = True - try: - from kernels import get_kernel - - tk = get_kernel("kernels-community/triton_kernels") - _CACHED_HANDLES = HFTritonHandles( - routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu - ) - return _CACHED_HANDLES - except Exception: - # Keep None in cache state to prevent repeated fetch attempts - return None - - -def route_topk(logits, top_k: int): - handles = load() - if handles is None: - return None - return handles.routing.routing_torch(logits, n_expts_act=top_k) - - -def swiglu(x, alpha, limit=1.0, routing_data=None): - handles = load() - if handles is None: - return None - pc = handles.swiglu.PrecisionConfig(limit=limit) - return handles.swiglu.swiglu(x, alpha, pc, routing_data) - - -def moe_ffn_forward_stub( - hidden_states, gate_linear, experts_module, top_k: int -) -> Tuple[object, object]: - """ - Temporary stub that uses kernels hub routing, but falls back to per-expert compute. - Returns (final_hidden_states, router_logits). - """ - import torch - import torch.nn.functional as F - - bsz, seqlen, hdim = hidden_states.shape - flat = hidden_states.view(-1, hdim) - router_logits = gate_linear(flat) - # Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down. 
-    # Fallback naive path for correctness
-    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
-    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
-    topk_weight = topk_weight.to(flat.dtype)
-    x_rep = flat.repeat_interleave(top_k, dim=0)
-    y = torch.empty_like(x_rep)
-    flat_idx = topk_idx.view(-1)
-    for i in range(experts_module.num_experts):
-        expert = experts_module[i]
-        sel = flat_idx == i
-        if sel.any():
-            y[sel] = expert(x_rep[sel])
-    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-    return y.reshape(bsz, seqlen, hdim), router_logits
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index 8e4b08652..2fb77869e 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -10,7 +10,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
 
     import torch.nn.functional as F
 
-    from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton
+    from axolotl.kernels.moe import backends as _moe_backends
     from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
 
     def mlp_forward(self, hidden_states):
@@ -28,19 +28,7 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
         router_logits = self.gate(hidden_states)
         preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
         backend = get_moe_backend_name(preferred)
-        if backend == MOEBackend.HF_TRITON and _hf_triton.available():
-            # Stub path: use kernels hub routing and fallback per-expert compute
-            try:
-                final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub(
-                    hidden_states.view(batch_size, sequence_length, hidden_dim),
-                    self.gate,
-                    self.experts,
-                    self.top_k,
-                )
-                return final_hidden_states, router_logits
-            except Exception as e:
-                warnings.warn(f"hf_triton backend failed, falling back to naive: {e}")
-        elif (
+        if (
             backend == MOEBackend.TORCH_GROUPED
             and not _moe_backends._probe_torch_grouped()
         ):
@@ -73,4 +61,24 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
         )
 
     MixtralBlockSparseTop2MLP.forward = mlp_forward
-    MixtralSparseMoeBlock.forward = moe_forward
+    # Wrap forward to support optional torch_grouped backend via config
+    from axolotl.kernels.moe import torch_grouped as _tg
+
+    preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
+    backend = get_moe_backend_name(preferred)
+
+    if backend == MOEBackend.TORCH_GROUPED and _tg.available():
+
+        def moe_forward_grouped(
+            self, hidden_states: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            y, router_logits = _tg.moe_ffn_forward_grouped(
+                hidden_states, self.gate, self.experts, self.top_k
+            )
+            if y is None:
+                return moe_forward(self, hidden_states)
+            return y, router_logits
+
+        MixtralSparseMoeBlock.forward = moe_forward_grouped
+    else:
+        MixtralSparseMoeBlock.forward = moe_forward
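Review note: for readers without the surrounding file, the `moe_forward` fallback that `moe_forward_grouped` degrades to is the standard softmax + top-k + per-expert loop, the same shape as `forward_naive` in the benchmark. A self-contained sketch of those semantics; `experts` is a hypothetical indexable collection of per-expert FFN modules:

```python
# Sketch of the naive per-expert MoE reference semantics.
import torch
import torch.nn.functional as F


def naive_moe(flat_x, router_logits, experts, top_k):
    # Router probabilities in float32, keep the top-k experts per token.
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
    # Renormalize the kept weights and match the activation dtype.
    topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat_x.dtype)
    x_rep = flat_x.repeat_interleave(top_k, dim=0)  # one row per (token, expert) pair
    y = torch.zeros_like(x_rep)
    flat_idx = topk_idx.reshape(-1)
    for i, expert in enumerate(experts):
        sel = flat_idx == i
        if sel.any():
            y[sel] = expert(x_rep[sel])
    # Weighted sum over each token's k expert outputs.
    return (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
```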
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 35e7cf5ca..a5cdacf9e 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -133,10 +133,10 @@ class AxolotlInputConfig(
     vllm: VllmConfig | None = Field(
         default_factory=lambda: VllmConfig(),
     )
 
-    moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
+    moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
+            "description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
         },
     )
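Review note: the payoff of `torch_grouped` is replacing that Python loop with a single grouped call over expert weights stacked as `[E, K, N]`, the same layout the deleted stub cached. Conceptual sketch only: `torch.bmm` stands in for the PyTorch 2.8+ grouped GEMM the real backend targets, and it assumes every expert receives the same number of tokens, which real routing does not guarantee.

```python
# Conceptual sketch: one batched call over all experts instead of a loop.
import torch

E, K, N, tokens_per_expert = 8, 64, 128, 16  # illustrative sizes
W = torch.randn(E, K, N)                  # stacked expert weights [E, K, N]
x = torch.randn(E, tokens_per_expert, K)  # tokens pre-gathered per expert

y = torch.bmm(x, W)  # one kernel covering all experts
assert y.shape == (E, tokens_per_expert, N)
```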