fast path

This commit is contained in:
Dan Saunders
2025-09-15 18:57:13 -04:00
parent 479b6144df
commit 125e7b5fe6
8 changed files with 79 additions and 16 deletions

View File

@@ -1,8 +1,8 @@
MoE Backends in Axolotl
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via an environment variable:
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
- AXOLOTL_MOE_BACKEND=auto|hf_triton|torch_grouped|naive
- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
Behavior
- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
@@ -12,7 +12,8 @@ Behavior
Notes
- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
- No changes to training scripts are required; Axolotl wraps Transformers Trainer; selection happens inside the model forward.
- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
Example
AXOLOTL_MOE_BACKEND=hf_triton accelerate launch -m axolotl.cli.train path/to/config.yaml
moe_backend: hf_triton
accelerate launch -m axolotl.cli.train path/to/config.yaml

View File

@@ -156,7 +156,6 @@ def main():
)
# HF Triton (stub compute for now)
os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
t_hf = forward_hf_triton
y = t_hf(x, gate, experts, args.top_k)
if y is not None:

View File

@@ -1,4 +1,3 @@
import os
import warnings
from enum import Enum
@@ -35,11 +34,10 @@ def _probe_hf_triton() -> bool:
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
- explicit preferred argument
- environment variable AXOLOTL_MOE_BACKEND
- explicit preferred argument (e.g., from config)
- auto detection
"""
choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower()
choice = (preferred or "auto").lower()
try:
selected = MOEBackend(choice)
except ValueError:

View File

@@ -79,9 +79,66 @@ def moe_ffn_forward_stub(
bsz, seqlen, hdim = hidden_states.shape
flat = hidden_states.view(-1, hdim)
router_logits = gate_linear(flat)
# For now, do not call routing to avoid extra overhead until
# grouped GEMM integration is complete. Use the naive compute path
# for correctness and baseline performance.
# Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
handles = load()
if handles is not None:
try:
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
router_logits, n_expts_act=top_k
)
# Prepare expert weights: shapes [E, K, N]
E = experts_module.num_experts
K = hdim
# up projections
W1 = []
W3 = []
for i in range(E):
exp = experts_module[i]
# Linear weight is [out, in]; need [in, out]
W1.append(exp.w1.weight.t())
W3.append(exp.w3.weight.t())
W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
# First matmul for w1: gather happens inside kernel using gather_indx
Y1 = handles.matmul_ogs.matmul_ogs(
flat,
W1,
None,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
# Second matmul for w3 on the same gathered order
Y3 = handles.matmul_ogs.matmul_ogs(
flat,
W3,
None,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
# SwiGLU: silu(Y1) * Y3
Hidden = F.silu(Y1) * Y3
# Down projection weights [E, inter, hidden]
W2 = [experts_module[i].w2.weight.t() for i in range(E)]
W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
# Down matmul with fused scatter back using scatter_indx
Out = handles.matmul_ogs.matmul_ogs(
Hidden,
W2,
None,
routing_data=routing_data,
gather_indx=None,
scatter_indx=scatter_idx,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
return Out.view(bsz, seqlen, hdim), router_logits
except Exception:
pass
# Fallback naive path for correctness
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)

View File

@@ -269,6 +269,7 @@ class PatchManager:
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
cfg=self.cfg,
)
if self.cfg.sample_packing:

View File

@@ -5,7 +5,7 @@ Patches to support multipack for mixtral
import torch
def patch_mixtral_moe_forward_zero3() -> None:
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
import warnings
import torch.nn.functional as F
@@ -26,7 +26,8 @@ def patch_mixtral_moe_forward_zero3() -> None:
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
backend = get_moe_backend_name()
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
# Stub path: use kernels hub routing and fallback per-expert compute
try:

View File

@@ -46,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -57,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
patch_mixtral_moe_forward_zero3(cfg)
def patch_remote(model_name):

View File

@@ -132,6 +132,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
default=None,
json_schema_extra={
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
},
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(