update
This commit is contained in:
@@ -79,40 +79,20 @@ def moe_ffn_forward_stub(
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
flat = hidden_states.view(-1, hdim)
|
||||
router_logits = gate_linear(flat)
|
||||
# use hub routing if available; otherwise fallback to softmax+topk
|
||||
routed = None
|
||||
if available():
|
||||
try:
|
||||
routed = route_topk(router_logits, top_k)
|
||||
except Exception:
|
||||
routed = None
|
||||
if routed is None:
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||
topk_weight = topk_weight.to(flat.dtype)
|
||||
x_rep = flat.repeat_interleave(top_k, dim=0)
|
||||
y = torch.empty_like(x_rep)
|
||||
flat_idx = topk_idx.view(-1)
|
||||
for i in range(experts_module.num_experts):
|
||||
expert = experts_module[i]
|
||||
y[flat_idx == i] = expert(x_rep[flat_idx == i])
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
return y.reshape(bsz, seqlen, hdim), router_logits
|
||||
# If routed via hub, still fallback to per-expert compute until grouped GEMM path is wired.
|
||||
ex_routing_data, gather_idx, scatter_idx = routed
|
||||
# Convert to naive per-expert compute on packed tokens (future: call matmul_ogs + swiglu)
|
||||
# For now, reconstruct the same result as naive path (no speedup but validates routing).
|
||||
# We map the selected experts from gather_idx back to expert ids via router_logits argmax among top-k.
|
||||
# Simpler: reuse naive computation for correctness; detailed integration will follow.
|
||||
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
# For now, do not call routing to avoid extra overhead until
|
||||
# grouped GEMM integration is complete. Use the naive compute path
|
||||
# for correctness and baseline performance.
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype)
|
||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||
topk_weight = topk_weight.to(flat.dtype)
|
||||
x_rep = flat.repeat_interleave(top_k, dim=0)
|
||||
y = torch.empty_like(x_rep)
|
||||
flat_idx = topk_idx.view(-1)
|
||||
for i in range(experts_module.num_experts):
|
||||
expert = experts_module[i]
|
||||
y[flat_idx == i] = expert(x_rep[flat_idx == i])
|
||||
sel = flat_idx == i
|
||||
if sel.any():
|
||||
y[sel] = expert(x_rep[sel])
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
return y.reshape(bsz, seqlen, hdim), router_logits
|
||||
|
||||
Reference in New Issue
Block a user