fast path
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
MoE Backends in Axolotl
|
MoE Backends in Axolotl
|
||||||
|
|
||||||
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via an environment variable:
|
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
|
||||||
|
|
||||||
- AXOLOTL_MOE_BACKEND=auto|hf_triton|torch_grouped|naive
|
- Set `moe_backend: auto|hf_triton|torch_grouped|naive`
|
||||||
|
|
||||||
Behavior
|
Behavior
|
||||||
- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
|
- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
|
||||||
@@ -12,7 +12,8 @@ Behavior
|
|||||||
|
|
||||||
Notes
|
Notes
|
||||||
- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
|
- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
|
||||||
- No changes to training scripts are required; Axolotl wraps Transformers Trainer; selection happens inside the model forward.
|
- No changes to training scripts are required; selection happens inside the model forward. The `AXOLOTL_MOE_BACKEND` environment variable is no longer used.
|
||||||
|
|
||||||
Example
|
Example
|
||||||
AXOLOTL_MOE_BACKEND=hf_triton accelerate launch -m axolotl.cli.train path/to/config.yaml
|
moe_backend: hf_triton
|
||||||
|
accelerate launch -m axolotl.cli.train path/to/config.yaml
|
||||||
|
|||||||
@@ -156,7 +156,6 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# HF Triton (stub compute for now)
|
# HF Triton (stub compute for now)
|
||||||
os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
|
|
||||||
t_hf = forward_hf_triton
|
t_hf = forward_hf_triton
|
||||||
y = t_hf(x, gate, experts, args.top_k)
|
y = t_hf(x, gate, experts, args.top_k)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -35,11 +34,10 @@ def _probe_hf_triton() -> bool:
|
|||||||
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
||||||
"""
|
"""
|
||||||
Resolve the desired MoE backend using, in order of precedence:
|
Resolve the desired MoE backend using, in order of precedence:
|
||||||
- explicit preferred argument
|
- explicit preferred argument (e.g., from config)
|
||||||
- environment variable AXOLOTL_MOE_BACKEND
|
|
||||||
- auto detection
|
- auto detection
|
||||||
"""
|
"""
|
||||||
choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower()
|
choice = (preferred or "auto").lower()
|
||||||
try:
|
try:
|
||||||
selected = MOEBackend(choice)
|
selected = MOEBackend(choice)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -79,9 +79,66 @@ def moe_ffn_forward_stub(
|
|||||||
bsz, seqlen, hdim = hidden_states.shape
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
flat = hidden_states.view(-1, hdim)
|
flat = hidden_states.view(-1, hdim)
|
||||||
router_logits = gate_linear(flat)
|
router_logits = gate_linear(flat)
|
||||||
# For now, do not call routing to avoid extra overhead until
|
# Fast path via kernels hub: route tokens, do grouped GEMMs for up, gate, and down.
|
||||||
# grouped GEMM integration is complete. Use the naive compute path
|
handles = load()
|
||||||
# for correctness and baseline performance.
|
if handles is not None:
|
||||||
|
try:
|
||||||
|
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
|
||||||
|
router_logits, n_expts_act=top_k
|
||||||
|
)
|
||||||
|
# Prepare expert weights: shapes [E, K, N]
|
||||||
|
E = experts_module.num_experts
|
||||||
|
K = hdim
|
||||||
|
# up projections
|
||||||
|
W1 = []
|
||||||
|
W3 = []
|
||||||
|
for i in range(E):
|
||||||
|
exp = experts_module[i]
|
||||||
|
# Linear weight is [out, in]; need [in, out]
|
||||||
|
W1.append(exp.w1.weight.t())
|
||||||
|
W3.append(exp.w3.weight.t())
|
||||||
|
W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
|
||||||
|
W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
|
||||||
|
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
|
||||||
|
# First matmul for w1: gather happens inside kernel using gather_indx
|
||||||
|
Y1 = handles.matmul_ogs.matmul_ogs(
|
||||||
|
flat,
|
||||||
|
W1,
|
||||||
|
None,
|
||||||
|
routing_data=routing_data,
|
||||||
|
gather_indx=gather_idx,
|
||||||
|
scatter_indx=None,
|
||||||
|
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||||
|
)
|
||||||
|
# Second matmul for w3 on the same gathered order
|
||||||
|
Y3 = handles.matmul_ogs.matmul_ogs(
|
||||||
|
flat,
|
||||||
|
W3,
|
||||||
|
None,
|
||||||
|
routing_data=routing_data,
|
||||||
|
gather_indx=gather_idx,
|
||||||
|
scatter_indx=None,
|
||||||
|
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||||
|
)
|
||||||
|
# SwiGLU: silu(Y1) * Y3
|
||||||
|
Hidden = F.silu(Y1) * Y3
|
||||||
|
# Down projection weights [E, inter, hidden]
|
||||||
|
W2 = [experts_module[i].w2.weight.t() for i in range(E)]
|
||||||
|
W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
|
||||||
|
# Down matmul with fused scatter back using scatter_indx
|
||||||
|
Out = handles.matmul_ogs.matmul_ogs(
|
||||||
|
Hidden,
|
||||||
|
W2,
|
||||||
|
None,
|
||||||
|
routing_data=routing_data,
|
||||||
|
gather_indx=None,
|
||||||
|
scatter_indx=scatter_idx,
|
||||||
|
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||||
|
)
|
||||||
|
return Out.view(bsz, seqlen, hdim), router_logits
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Fallback naive path for correctness
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
|||||||
@@ -269,6 +269,7 @@ class PatchManager:
|
|||||||
self.cfg.model_config_type,
|
self.cfg.model_config_type,
|
||||||
model_name=self.cfg.base_model,
|
model_name=self.cfg.base_model,
|
||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
|
cfg=self.cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.sample_packing:
|
if self.cfg.sample_packing:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Patches to support multipack for mixtral
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def patch_mixtral_moe_forward_zero3() -> None:
|
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -26,7 +26,8 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
# router_logits: (batch * sequence_length, n_experts)
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
backend = get_moe_backend_name()
|
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||||
|
backend = get_moe_backend_name(preferred)
|
||||||
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
|
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
|
||||||
# Stub path: use kernels hub routing and fallback per-expert compute
|
# Stub path: use kernels hub routing and fallback per-expert compute
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
|
||||||
if has_remote_code:
|
if has_remote_code:
|
||||||
patch_remote(model_name)
|
patch_remote(model_name)
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
@@ -57,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
|||||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||||
|
|
||||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
patch_mixtral_moe_forward_zero3()
|
patch_mixtral_moe_forward_zero3(cfg)
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name):
|
def patch_remote(model_name):
|
||||||
|
|||||||
@@ -132,6 +132,12 @@ class AxolotlInputConfig(
|
|||||||
vllm: VllmConfig | None = Field(
|
vllm: VllmConfig | None = Field(
|
||||||
default_factory=lambda: VllmConfig(),
|
default_factory=lambda: VllmConfig(),
|
||||||
)
|
)
|
||||||
|
moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
|
||||||
|
},
|
||||||
|
)
|
||||||
qat: QATConfig | None = None
|
qat: QATConfig | None = None
|
||||||
quantization: PTQConfig | None = None
|
quantization: PTQConfig | None = None
|
||||||
reward_model: bool | None = Field(
|
reward_model: bool | None = Field(
|
||||||
|
|||||||
Reference in New Issue
Block a user