fast path
@@ -1,8 +1,8 @@
 MoE Backends in Axolotl

-Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via an environment variable:
+Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):

-- AXOLOTL_MOE_BACKEND=auto|hf_triton|torch_grouped|naive
+- Set `moe_backend: auto|hf_triton|torch_grouped|naive`

 Behavior
 - auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
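
Note: the auto preference order described above amounts to roughly the following (a sketch only; `_resolve_auto` and `_probe_torch_grouped` are hypothetical names, while `_probe_hf_triton` and `MOEBackend` do appear later in this commit):

    import torch
    from packaging.version import Version

    def _resolve_auto() -> MOEBackend:
        # Documented preference order for moe_backend: auto.
        if Version(torch.__version__.split("+")[0]) >= Version("2.8") and _probe_torch_grouped():
            return MOEBackend.TORCH_GROUPED  # PyTorch 2.8+ grouped GEMM
        if _probe_hf_triton():  # Hugging Face kernels hub available
            return MOEBackend.HF_TRITON
        return MOEBackend.NAIVE  # per-expert fallback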
@@ -12,7 +12,8 @@ Behavior

 Notes
 - Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
-- No changes to training scripts are required; Axolotl wraps Transformers Trainer; selection happens inside the model forward.
+- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.

 Example
-AXOLOTL_MOE_BACKEND=hf_triton accelerate launch -m axolotl.cli.train path/to/config.yaml
+moe_backend: hf_triton
+accelerate launch -m axolotl.cli.train path/to/config.yaml
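
Note: the two pieces of the new-style example live in different places — the `moe_backend` line goes in the training config, and the launch command stays unchanged. A minimal config sketch (all keys other than `moe_backend` are illustrative, not part of this commit):

    # path/to/config.yaml
    base_model: mistralai/Mixtral-8x7B-v0.1  # any Mixtral-style MoE model
    moe_backend: hf_triton                   # auto | hf_triton | torch_grouped | naive
    sequence_len: 4096
    micro_batch_size: 1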
@@ -156,7 +156,6 @@ def main():
     )

     # HF Triton (stub compute for now)
-    os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
     t_hf = forward_hf_triton
     y = t_hf(x, gate, experts, args.top_k)
     if y is not None:
@@ -1,4 +1,3 @@
-import os
 import warnings
 from enum import Enum

@@ -35,11 +34,10 @@ def _probe_hf_triton() -> bool:
 def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     """
     Resolve the desired MoE backend using, in order of precedence:
-    - explicit preferred argument
-    - environment variable AXOLOTL_MOE_BACKEND
+    - explicit preferred argument (e.g., from config)
     - auto detection
     """
-    choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower()
+    choice = (preferred or "auto").lower()
     try:
         selected = MOEBackend(choice)
     except ValueError:
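
Note: with the environment variable gone, resolution is driven entirely by the argument. A usage sketch (assuming explicit choices map straight to enum members, as the `MOEBackend(choice)` line suggests):

    # Explicit preference (e.g., cfg.moe_backend) wins:
    backend = get_moe_backend_name("hf_triton")  # -> MOEBackend.HF_TRITON
    # No preference falls through to "auto" detection:
    backend = get_moe_backend_name(None)
    # Setting AXOLOTL_MOE_BACKEND in the environment now has no effect.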
@@ -79,9 +79,66 @@ def moe_ffn_forward_stub(
     bsz, seqlen, hdim = hidden_states.shape
     flat = hidden_states.view(-1, hdim)
     router_logits = gate_linear(flat)
-    # For now, do not call routing to avoid extra overhead until
-    # grouped GEMM integration is complete. Use the naive compute path
-    # for correctness and baseline performance.
+    # Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
+    handles = load()
+    if handles is not None:
+        try:
+            routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
+                router_logits, n_expts_act=top_k
+            )
+            # Prepare expert weights: shapes [E, K, N]
+            E = experts_module.num_experts
+            K = hdim
+            # up projections
+            W1 = []
+            W3 = []
+            for i in range(E):
+                exp = experts_module[i]
+                # Linear weight is [out, in]; need [in, out]
+                W1.append(exp.w1.weight.t())
+                W3.append(exp.w3.weight.t())
+            W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
+            W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
+            # compute gathered inputs X_g according to gather_idx via matmul_ogs gather
+            # First matmul for w1: gather happens inside kernel using gather_indx
+            Y1 = handles.matmul_ogs.matmul_ogs(
+                flat,
+                W1,
+                None,
+                routing_data=routing_data,
+                gather_indx=gather_idx,
+                scatter_indx=None,
+                precision_config=handles.matmul_ogs.PrecisionConfig(),
+            )
+            # Second matmul for w3 on the same gathered order
+            Y3 = handles.matmul_ogs.matmul_ogs(
+                flat,
+                W3,
+                None,
+                routing_data=routing_data,
+                gather_indx=gather_idx,
+                scatter_indx=None,
+                precision_config=handles.matmul_ogs.PrecisionConfig(),
+            )
+            # SwiGLU: silu(Y1) * Y3
+            Hidden = F.silu(Y1) * Y3
+            # Down projection weights [E, inter, hidden]
+            W2 = [experts_module[i].w2.weight.t() for i in range(E)]
+            W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
+            # Down matmul with fused scatter back using scatter_indx
+            Out = handles.matmul_ogs.matmul_ogs(
+                Hidden,
+                W2,
+                None,
+                routing_data=routing_data,
+                gather_indx=None,
+                scatter_indx=scatter_idx,
+                precision_config=handles.matmul_ogs.PrecisionConfig(),
+            )
+            return Out.view(bsz, seqlen, hdim), router_logits
+        except Exception:
+            pass
+    # Fallback naive path for correctness
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
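
Note: the three trailing context lines open the naive fallback; for comparison with the grouped-GEMM fast path above, the rest of that per-expert loop typically looks like the following (a sketch, not code from this diff; it assumes Mixtral-style experts with `w1`/`w2`/`w3` Linear layers and SwiGLU, matching the fast path):

    # Hypothetical continuation of the naive fallback (per-expert compute):
    out = torch.zeros_like(flat)
    for e in range(experts_module.num_experts):
        token_idx, k_idx = torch.where(topk_idx == e)  # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        x_e = flat[token_idx]
        expert = experts_module[e]
        h = F.silu(expert.w1(x_e)) * expert.w3(x_e)    # same SwiGLU as the fast path
        w = topk_weight[token_idx, k_idx].unsqueeze(-1).to(flat.dtype)
        out.index_add_(0, token_idx, expert.w2(h) * w) # weight and scatter back
    return out.view(bsz, seqlen, hdim), router_logits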
@@ -269,6 +269,7 @@ class PatchManager:
                 self.cfg.model_config_type,
                 model_name=self.cfg.base_model,
                 has_remote_code=has_remote_code,
+                cfg=self.cfg,
             )

         if self.cfg.sample_packing:
@@ -5,7 +5,7 @@ Patches to support multipack for mixtral
 import torch


-def patch_mixtral_moe_forward_zero3() -> None:
+def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
     import warnings

     import torch.nn.functional as F
@@ -26,7 +26,8 @@ def patch_mixtral_moe_forward_zero3() -> None:
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
-        backend = get_moe_backend_name()
+        preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
+        backend = get_moe_backend_name(preferred)
         if backend == MOEBackend.HF_TRITON and _hf_triton.available():
            # Stub path: use kernels hub routing and fallback per-expert compute
            try:
@@ -46,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]


-def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
+def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
     if has_remote_code:
         patch_remote(model_name)
     elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -57,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
         transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data

     if model_type == "mixtral" and is_deepspeed_zero3_enabled():
-        patch_mixtral_moe_forward_zero3()
+        patch_mixtral_moe_forward_zero3(cfg)


 def patch_remote(model_name):
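
Note: taken together, the hunks above thread the config from trainer setup down to the patched forward; the selection path after this commit is roughly:

    # Summary sketch of the cfg plumbing (call sites outside this diff are assumed):
    # PatchManager (holds self.cfg)
    #   -> patch_for_multipack(model_type, ..., cfg=self.cfg)
    #        -> patch_mixtral_moe_forward_zero3(cfg)  # mixtral + DeepSpeed ZeRO-3 only
    #             -> get_moe_backend_name(getattr(cfg, "moe_backend", None))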
@@ -132,6 +132,12 @@ class AxolotlInputConfig(
     vllm: VllmConfig | None = Field(
         default_factory=lambda: VllmConfig(),
     )
+    moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
+        },
+    )
     qat: QATConfig | None = None
     quantization: PTQConfig | None = None
     reward_model: bool | None = Field(
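
Note: because the field is typed as a Literal, an invalid value fails at config-validation time rather than deep inside training. A sketch (other required config fields are omitted for brevity and would also need to be supplied):

    from pydantic import ValidationError

    try:
        AxolotlInputConfig(moe_backend="triton")  # typo: not in the Literal set
    except ValidationError as err:
        print(err)  # pydantic reports moe_backend among the offending fields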