fix
This commit is contained in:
@@ -86,19 +86,42 @@ def moe_ffn_forward_stub(
|
|||||||
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
|
routing_data, gather_idx, scatter_idx = handles.routing.routing_torch(
|
||||||
router_logits, n_expts_act=top_k
|
router_logits, n_expts_act=top_k
|
||||||
)
|
)
|
||||||
# Prepare expert weights: shapes [E, K, N]
|
# Prepare and cache expert weights: shapes [E, K, N]
|
||||||
|
import torch
|
||||||
|
|
||||||
E = experts_module.num_experts
|
E = experts_module.num_experts
|
||||||
K = hdim
|
dev = flat.device
|
||||||
# up projections
|
dt = flat.dtype
|
||||||
W1 = []
|
if (
|
||||||
W3 = []
|
not hasattr(experts_module, "_stacked_w1")
|
||||||
for i in range(E):
|
or experts_module._stacked_w1.device != dev
|
||||||
exp = experts_module[i]
|
or experts_module._stacked_w1.dtype != dt
|
||||||
# Linear weight is [out, in]; need [in, out]
|
):
|
||||||
W1.append(exp.w1.weight.t())
|
W1 = []
|
||||||
W3.append(exp.w3.weight.t())
|
W3 = []
|
||||||
W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype)
|
W2 = []
|
||||||
W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype)
|
for i in range(E):
|
||||||
|
exp = experts_module[i]
|
||||||
|
W1.append(exp.w1.weight.t())
|
||||||
|
W3.append(exp.w3.weight.t())
|
||||||
|
W2.append(exp.w2.weight.t())
|
||||||
|
experts_module._stacked_w1 = (
|
||||||
|
torch.stack(W1, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w3 = (
|
||||||
|
torch.stack(W3, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w2 = (
|
||||||
|
torch.stack(W2, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
W1 = experts_module._stacked_w1
|
||||||
|
W3 = experts_module._stacked_w3
|
||||||
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
|
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
|
||||||
# First matmul for w1: gather happens inside kernel using gather_indx
|
# First matmul for w1: gather happens inside kernel using gather_indx
|
||||||
Y1 = handles.matmul_ogs.matmul_ogs(
|
Y1 = handles.matmul_ogs.matmul_ogs(
|
||||||
@@ -123,8 +146,7 @@ def moe_ffn_forward_stub(
|
|||||||
# SwiGLU: silu(Y1) * Y3
|
# SwiGLU: silu(Y1) * Y3
|
||||||
Hidden = F.silu(Y1) * Y3
|
Hidden = F.silu(Y1) * Y3
|
||||||
# Down projection weights [E, inter, hidden]
|
# Down projection weights [E, inter, hidden]
|
||||||
W2 = [experts_module[i].w2.weight.t() for i in range(E)]
|
W2 = experts_module._stacked_w2
|
||||||
W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype)
|
|
||||||
# Down matmul with fused scatter back using scatter_indx
|
# Down matmul with fused scatter back using scatter_indx
|
||||||
Out = handles.matmul_ogs.matmul_ogs(
|
Out = handles.matmul_ogs.matmul_ogs(
|
||||||
Hidden,
|
Hidden,
|
||||||
@@ -134,6 +156,7 @@ def moe_ffn_forward_stub(
|
|||||||
gather_indx=None,
|
gather_indx=None,
|
||||||
scatter_indx=scatter_idx,
|
scatter_indx=scatter_idx,
|
||||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||||
|
gammas=routing_data.gate_scal,
|
||||||
)
|
)
|
||||||
return Out.view(bsz, seqlen, hdim), router_logits
|
return Out.view(bsz, seqlen, hdim), router_logits
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
StringConstraints,
|
StringConstraints,
|
||||||
field_serializer,
|
field_serializer,
|
||||||
|
field_validator,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -138,6 +139,8 @@ class AxolotlInputConfig(
|
|||||||
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
|
"description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Value is constrained by the Literal type; no normalization needed.
|
||||||
qat: QATConfig | None = None
|
qat: QATConfig | None = None
|
||||||
quantization: PTQConfig | None = None
|
quantization: PTQConfig | None = None
|
||||||
reward_model: bool | None = Field(
|
reward_model: bool | None = Field(
|
||||||
|
|||||||
Reference in New Issue
Block a user