fix
@@ -86,19 +86,42 @@ def moe_ffn_forward_stub(
         routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
             router_logits, n_expts_act=top_k
         )
-        # Prepare expert weights: shapes [E, K, N]
+        # Prepare and cache expert weights: shapes [E, K, N]
         import torch

         E = experts_module.num_experts
         K = hdim
-        # up projections
-        W1 = []
-        W3 = []
-        for i in range(E):
-            exp = experts_module[i]
-            # Linear weight is [out, in]; need [in, out]
-            W1.append(exp.w1.weight.t())
-            W3.append(exp.w3.weight.t())
-        W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
-        W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
+        dev = flat.device
+        dt = flat.dtype
+        if (
+            not hasattr(experts_module, "_stacked_w1")
+            or experts_module._stacked_w1.device != dev
+            or experts_module._stacked_w1.dtype != dt
+        ):
+            W1 = []
+            W3 = []
+            W2 = []
+            for i in range(E):
+                exp = experts_module[i]
+                W1.append(exp.w1.weight.t())
+                W3.append(exp.w3.weight.t())
+                W2.append(exp.w2.weight.t())
+            experts_module._stacked_w1 = (
+                torch.stack(W1, dim=0)
+                .to(device=dev, dtype=dt, non_blocking=True)
+                .contiguous()
+            )
+            experts_module._stacked_w3 = (
+                torch.stack(W3, dim=0)
+                .to(device=dev, dtype=dt, non_blocking=True)
+                .contiguous()
+            )
+            experts_module._stacked_w2 = (
+                torch.stack(W2, dim=0)
+                .to(device=dev, dtype=dt, non_blocking=True)
+                .contiguous()
+            )
+        W1 = experts_module._stacked_w1
+        W3 = experts_module._stacked_w3
         # compute gathered inputs X_g according to gather_idx via matmul_ogs gather
         # First matmul for w1: gather happens inside kernel using gather_indx
         Y1 = handles.matmul_ogs.matmul_ogs(
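The new branch above builds the stacked [E, K, N] expert weight tensors once, moves them to the activation's device and dtype, and stores them on the expert container, instead of re-stacking and copying every weight on each forward pass. A minimal self-contained sketch of the same caching pattern follows; the `Expert` class and `get_stacked_weights` helper are illustrative assumptions, not the patched code:

```python
import torch
import torch.nn as nn


class Expert(nn.Module):
    """Toy SwiGLU expert with up projections (w1, w3) and a down projection (w2)."""

    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.w1 = nn.Linear(hidden, inter, bias=False)
        self.w3 = nn.Linear(hidden, inter, bias=False)
        self.w2 = nn.Linear(inter, hidden, bias=False)


def get_stacked_weights(experts: nn.ModuleList, x: torch.Tensor):
    """Return cached stacked weights, rebuilding only when the cache is missing
    or no longer matches the activation's device/dtype."""
    cache = getattr(experts, "_stacked_w1", None)
    if cache is None or cache.device != x.device or cache.dtype != x.dtype:
        for name in ("w1", "w3", "w2"):
            # nn.Linear stores weight as [out, in]; transpose to [in, out] per expert.
            stacked = torch.stack(
                [getattr(e, name).weight.t() for e in experts], dim=0
            ).to(device=x.device, dtype=x.dtype).contiguous()
            setattr(experts, f"_stacked_{name}", stacked)
    return experts._stacked_w1, experts._stacked_w3, experts._stacked_w2


experts = nn.ModuleList([Expert(hidden=16, inter=32) for _ in range(4)])
x = torch.randn(8, 16)
W1, W3, W2 = get_stacked_weights(experts, x)      # first call builds the cache
W1_again, _, _ = get_stacked_weights(experts, x)  # second call reuses it
assert W1 is W1_again and W1.shape == (4, 16, 32)
```

Note that a cache keyed only on device and dtype assumes the expert weights themselves are frozen, i.e. an inference-style setting rather than training.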
@@ -123,8 +146,7 @@ def moe_ffn_forward_stub(
         # SwiGLU: silu(Y1) * Y3
         Hidden = F.silu(Y1) * Y3
         # Down projection weights [E, inter, hidden]
-        W2 = [experts_module[i].w2.weight.t() for i in range(E)]
-        W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
+        W2 = experts_module._stacked_w2
         # Down matmul with fused scatter back using scatter_indx
         Out = handles.matmul_ogs.matmul_ogs(
             Hidden,
@@ -134,6 +156,7 @@ def moe_ffn_forward_stub(
             gather_indx=None,
             scatter_indx=scatter_idx,
             precision_config=handles.matmul_ogs.PrecisionConfig(),
+            gammas=routing_data.gate_scal,
         )
         return Out.view(bsz, seqlen, hdim), router_logits
     except Exception:
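Taken together, the hunks in this file implement a SwiGLU MoE feed-forward: tokens are routed to their top_k experts, gathered, run through the w1/w3 up projections, combined as silu(Y1) * Y3, projected back down through w2, and scattered to their original positions with the routing gates applied. The sketch below is a naive per-expert reference for that computation under simplified assumptions (plain top-k softmax routing, dense loops); it is not the fused matmul_ogs gather/scatter path:

```python
import torch
import torch.nn.functional as F


def naive_moe_ffn(x, router_logits, W1, W3, W2, top_k):
    """Reference SwiGLU MoE FFN.

    x:             [T, hidden]  flattened tokens
    router_logits: [T, E]
    W1, W3:        [E, hidden, inter]  stacked up projections
    W2:            [E, inter, hidden]  stacked down projection
    """
    gate_scores, expert_idx = router_logits.softmax(dim=-1).topk(top_k, dim=-1)
    out = torch.zeros_like(x)
    for e in range(W1.shape[0]):
        token_idx, slot = (expert_idx == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x_e = x[token_idx]                         # gather tokens routed to expert e
        h = F.silu(x_e @ W1[e]) * (x_e @ W3[e])    # SwiGLU
        y_e = h @ W2[e]                            # down projection
        gates = gate_scores[token_idx, slot].unsqueeze(-1)
        out.index_add_(0, token_idx, y_e * gates)  # scatter back, scaled by the gate
    return out


T, E, hidden, inter, top_k = 6, 4, 16, 32, 2
out = naive_moe_ffn(
    torch.randn(T, hidden),
    torch.randn(T, E),
    torch.randn(E, hidden, inter),
    torch.randn(E, hidden, inter),
    torch.randn(E, inter, hidden),
    top_k,
)
print(out.shape)  # torch.Size([6, 16])
```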
@@ -9,6 +9,7 @@ from pydantic import (
     Field,
     StringConstraints,
     field_serializer,
     field_validator,
+    model_validator,
 )

@@ -138,6 +139,8 @@ class AxolotlInputConfig(
             "description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
         },
     )
+
+    # Value is constrained by the Literal type; no normalization needed.
     qat: QATConfig | None = None
     quantization: PTQConfig | None = None
     reward_model: bool | None = Field(
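The comment added in the last hunk notes that the MoE backend value is already constrained by its Literal type, so no normalization pass is needed: pydantic rejects anything outside the allowed strings at validation time. A minimal sketch of that behaviour follows; the `MoESettings` model and `moe_backend` field name are assumptions for illustration, not the actual AxolotlInputConfig definition:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class MoESettings(BaseModel):
    # Hypothetical field: the Literal type does the constraint checking on its own.
    moe_backend: Literal["auto", "hf_triton", "torch_grouped", "naive"] = "auto"


print(MoESettings().moe_backend)                     # auto
print(MoESettings(moe_backend="naive").moe_backend)  # naive

try:
    MoESettings(moe_backend="fastest")               # not in the Literal -> rejected
except ValidationError as exc:
    print(exc.errors()[0]["loc"])                    # ('moe_backend',)
```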