This commit is contained in:
Dan Saunders
2025-09-15 19:00:58 -04:00
parent 125e7b5fe6
commit cfefad1eea
2 changed files with 40 additions and 14 deletions

View File

@@ -86,19 +86,42 @@ def moe_ffn_forward_stub(
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
router_logits, n_expts_act=top_k
)
# Prepare expert weights: shapes [E, K, N]
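# Prepare and cache the stacked expert weights on experts_module: shapes [E, K, N].
# NOTE(review): the cache is rebuilt only when it is missing or the input's
# device/dtype differs; it is not refreshed when the expert weights themselves
# change — confirm that is intended.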
import torch
E = experts_module.num_experts
K = hdim
# up projections
W1 = []
W3 = []
for i in range(E):
exp = experts_module[i]
# Linear weight is [out, in]; need [in, out]
W1.append(exp.w1.weight.t())
W3.append(exp.w3.weight.t())
W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
dev = flat.device
dt = flat.dtype
if (
not hasattr(experts_module, "_stacked_w1")
or experts_module._stacked_w1.device != dev
or experts_module._stacked_w1.dtype != dt
):
W1 = []
W3 = []
W2 = []
for i in range(E):
exp = experts_module[i]
W1.append(exp.w1.weight.t())
W3.append(exp.w3.weight.t())
W2.append(exp.w2.weight.t())
experts_module._stacked_w1 = (
torch.stack(W1, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
experts_module._stacked_w3 = (
torch.stack(W3, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
experts_module._stacked_w2 = (
torch.stack(W2, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
W1 = experts_module._stacked_w1
W3 = experts_module._stacked_w3
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
# First matmul for w1: gather happens inside kernel using gather_indx
Y1 = handles.matmul_ogs.matmul_ogs(
@@ -123,8 +146,7 @@ def moe_ffn_forward_stub(
# SwiGLU: silu(Y1) * Y3
Hidden = F.silu(Y1) * Y3
# Down projection weights [E, inter, hidden]
W2 = [experts_module[i].w2.weight.t() for i in range(E)]
W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
W2 = experts_module._stacked_w2
# Down matmul with fused scatter back using scatter_indx
Out = handles.matmul_ogs.matmul_ogs(
Hidden,
@@ -134,6 +156,7 @@ def moe_ffn_forward_stub(
gather_indx=None,
scatter_indx=scatter_idx,
precision_config=handles.matmul_ogs.PrecisionConfig(),
gammas=routing_data.gate_scal,
)
return Out.view(bsz, seqlen, hdim), router_logits
except Exception:

View File

@@ -9,6 +9,7 @@ from pydantic import (
Field,
StringConstraints,
field_serializer,
field_validator,
model_validator,
)
@@ -138,6 +139,8 @@ class AxolotlInputConfig(
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
},
)
# Value is constrained by the Literal type; no normalization needed.
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(