This commit is contained in:
Dan Saunders
2025-09-15 18:48:43 -04:00
parent 0d689bb421
commit 68da65cba2

View File

@@ -79,40 +79,20 @@ def moe_ffn_forward_stub(
bsz, seqlen, hdim = hidden_states.shape
flat = hidden_states.view(-1, hdim)
router_logits = gate_linear(flat)
# use hub routing if available; otherwise fallback to softmax+topk
routed = None
if available():
try:
routed = route_topk(router_logits, top_k)
except Exception:
routed = None
if routed is None:
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(flat.dtype)
x_rep = flat.repeat_interleave(top_k, dim=0)
y = torch.empty_like(x_rep)
flat_idx = topk_idx.view(-1)
for i in range(experts_module.num_experts):
expert = experts_module[i]
y[flat_idx == i] = expert(x_rep[flat_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
return y.reshape(bsz, seqlen, hdim), router_logits
# If routed via hub, still fallback to per-expert compute until grouped GEMM path is wired.
ex_routing_data, gather_idx, scatter_idx = routed
# Convert to naive per-expert compute on packed tokens (future: call matmul_ogs + swiglu)
# For now, reconstruct the same result as naive path (no speedup but validates routing).
# We map the selected experts from gather_idx back to expert ids via router_logits argmax among top-k.
# Simpler: reuse naive computation for correctness; detailed integration will follow.
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
# For now, do not call routing to avoid extra overhead until
# grouped GEMM integration is complete. Use the naive compute path
# for correctness and baseline performance.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(flat.dtype)
x_rep = flat.repeat_interleave(top_k, dim=0)
y = torch.empty_like(x_rep)
flat_idx = topk_idx.view(-1)
for i in range(experts_module.num_experts):
expert = experts_module[i]
y[flat_idx == i] = expert(x_rep[flat_idx == i])
sel = flat_idx == i
if sel.any():
y[sel] = expert(x_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
return y.reshape(bsz, seqlen, hdim), router_logits