From 125e7b5fe628c98e23771db860a9ff255c08a78d Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 15 Sep 2025 18:57:13 -0400
Subject: [PATCH] fast path

---
 docs/moe_backends.md                        |  9 +--
 scripts/bench_moe.py                        |  1 -
 src/axolotl/kernels/moe/backends.py         |  6 +-
 src/axolotl/kernels/moe/hf_triton.py        | 63 ++++++++++++++++++++-
 src/axolotl/loaders/patch_manager.py        |  1 +
 src/axolotl/monkeypatch/mixtral/__init__.py |  5 +-
 src/axolotl/monkeypatch/multipack.py        |  4 +-
 src/axolotl/utils/schemas/config.py         |  6 ++
 8 files changed, 79 insertions(+), 16 deletions(-)

diff --git a/docs/moe_backends.md b/docs/moe_backends.md
index 150fa0eb0..50731d226 100644
--- a/docs/moe_backends.md
+++ b/docs/moe_backends.md
@@ -1,8 +1,8 @@
 MoE Backends in Axolotl
 
-Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via an environment variable:
+Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
 
-- AXOLOTL_MOE_BACKEND=auto|hf_triton|torch_grouped|naive
+- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
 
 Behavior
 - auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
@@ -12,7 +12,8 @@ Behavior
 
 Notes
 - Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
-- No changes to training scripts are required; Axolotl wraps Transformers Trainer; selection happens inside the model forward.
+- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
 
 Example
-AXOLOTL_MOE_BACKEND=hf_triton accelerate launch -m axolotl.cli.train path/to/config.yaml
+moe_backend: hf_triton
+accelerate launch -m axolotl.cli.train path/to/config.yaml
diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index 635ef6a79..5c5f07d78 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -156,7 +156,6 @@ def main():
     )
 
     # HF Triton (stub compute for now)
-    os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
     t_hf = forward_hf_triton
     y = t_hf(x, gate, experts, args.top_k)
     if y is not None:
diff --git a/src/axolotl/kernels/moe/backends.py b/src/axolotl/kernels/moe/backends.py
index c625718a2..210db6040 100644
--- a/src/axolotl/kernels/moe/backends.py
+++ b/src/axolotl/kernels/moe/backends.py
@@ -1,4 +1,3 @@
-import os
 import warnings
 
 from enum import Enum
@@ -35,11 +34,10 @@ def _probe_hf_triton() -> bool:
 def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     """
     Resolve the desired MoE backend using, in order of precedence:
-    - explicit preferred argument
-    - environment variable AXOLOTL_MOE_BACKEND
+    - explicit preferred argument (e.g., from config)
     - auto detection
     """
-    choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower()
+    choice = (preferred or "auto").lower()
     try:
         selected = MOEBackend(choice)
     except ValueError:
diff --git a/src/axolotl/kernels/moe/hf_triton.py b/src/axolotl/kernels/moe/hf_triton.py
index b79c01cb3..a04af5359 100644
--- a/src/axolotl/kernels/moe/hf_triton.py
+++ b/src/axolotl/kernels/moe/hf_triton.py
@@ -79,9 +79,66 @@ def moe_ffn_forward_stub(
     bsz, seqlen, hdim = hidden_states.shape
     flat = hidden_states.view(-1, hdim)
     router_logits = gate_linear(flat)
-    # For now, do not call routing to avoid extra overhead until
-    # grouped GEMM integration is complete. Use the naive compute path
-    # for correctness and baseline performance.
+    # Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
+    handles = load()
+    if handles is not None:
+        try:
+            routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
+                router_logits, n_expts_act=top_k
+            )
+            # Prepare expert weights: shapes [E, K, N]
+            E = experts_module.num_experts
+            K = hdim
+            # up projections
+            W1 = []
+            W3 = []
+            for i in range(E):
+                exp = experts_module[i]
+                # Linear weight is [out, in]; need [in, out]
+                W1.append(exp.w1.weight.t())
+                W3.append(exp.w3.weight.t())
+            W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
+            W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
+            # compute gathered inputs X_g according to gather_idx via matmul_ogs gather
+            # First matmul for w1: gather happens inside kernel using gather_indx
+            Y1 = handles.matmul_ogs.matmul_ogs(
+                flat,
+                W1,
+                None,
+                routing_data=routing_data,
+                gather_indx=gather_idx,
+                scatter_indx=None,
+                precision_config=handles.matmul_ogs.PrecisionConfig(),
+            )
+            # Second matmul for w3 on the same gathered order
+            Y3 = handles.matmul_ogs.matmul_ogs(
+                flat,
+                W3,
+                None,
+                routing_data=routing_data,
+                gather_indx=gather_idx,
+                scatter_indx=None,
+                precision_config=handles.matmul_ogs.PrecisionConfig(),
+            )
+            # SwiGLU: silu(Y1) * Y3
+            Hidden = F.silu(Y1) * Y3
+            # Down projection weights [E, inter, hidden]
+            W2 = [experts_module[i].w2.weight.t() for i in range(E)]
+            W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
+            # Down matmul with fused scatter back using scatter_indx
+            Out = handles.matmul_ogs.matmul_ogs(
+                Hidden,
+                W2,
+                None,
+                routing_data=routing_data,
+                gather_indx=None,
+                scatter_indx=scatter_idx,
+                precision_config=handles.matmul_ogs.PrecisionConfig(),
+            )
+            return Out.view(bsz, seqlen, hdim), router_logits
+        except Exception:
+            pass
+    # Fallback naive path for correctness
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index a5a630cb5..71933a9dd 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -269,6 +269,7 @@ class PatchManager:
                 self.cfg.model_config_type,
                 model_name=self.cfg.base_model,
                 has_remote_code=has_remote_code,
+                cfg=self.cfg,
             )
 
         if self.cfg.sample_packing:
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index 988a1c6f9..8e4b08652 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -5,7 +5,7 @@ Patches to support multipack for mixtral
 import torch
 
 
-def patch_mixtral_moe_forward_zero3() -> None:
+def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
     import warnings
 
     import torch.nn.functional as F
@@ -26,7 +26,8 @@
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
-        backend = get_moe_backend_name()
+        preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
+        backend = get_moe_backend_name(preferred)
         if backend == MOEBackend.HF_TRITON and _hf_triton.available():
             # Stub path: use kernels hub routing and fallback per-expert compute
             try:
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index a32430d9f..03db780dc 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -46,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]
 
 
-def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
+def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
     if has_remote_code:
         patch_remote(model_name)
     elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -57,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
         transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
 
     if model_type == "mixtral" and is_deepspeed_zero3_enabled():
-        patch_mixtral_moe_forward_zero3()
+        patch_mixtral_moe_forward_zero3(cfg)
 
 
 def patch_remote(model_name):
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d612ec8a5..3b4a890b5 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -132,6 +132,12 @@ class AxolotlInputConfig(
     vllm: VllmConfig | None = Field(
         default_factory=lambda: VllmConfig(),
     )
+    moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
+        },
+    )
     qat: QATConfig | None = None
     quantization: PTQConfig | None = None
     reward_model: bool | None = Field(
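
Usage sketch (not part of the patch): a minimal illustration of how the new `moe_backend` config value is expected to flow into backend selection after this change. It assumes `MOEBackend` and `get_moe_backend_name` are both importable from `axolotl.kernels.moe.backends` once the patch is applied; the `SimpleNamespace` below is only a stand-in for Axolotl's parsed YAML config object.

    from types import SimpleNamespace

    from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

    # Stand-in for the parsed training config; in real runs this comes from
    # `moe_backend: hf_triton` in the YAML and is carried on Axolotl's cfg object.
    cfg = SimpleNamespace(moe_backend="hf_triton")

    # Same resolution the patched Mixtral forward performs: the explicit config
    # value takes precedence, otherwise "auto" detection is used.
    preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
    backend = get_moe_backend_name(preferred)
    print(backend)  # MOEBackend.HF_TRITON when the kernels hub path is usable

If `moe_backend` is unset, `get_moe_backend_name` resolves to "auto" and prefers PyTorch 2.8+ grouped GEMM, then the Hugging Face kernels hub, otherwise the naive path, as described in docs/moe_backends.md.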