update
This commit is contained in:
@@ -79,40 +79,20 @@ def moe_ffn_forward_stub(
|
|||||||
bsz, seqlen, hdim = hidden_states.shape
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
flat = hidden_states.view(-1, hdim)
|
flat = hidden_states.view(-1, hdim)
|
||||||
router_logits = gate_linear(flat)
|
router_logits = gate_linear(flat)
|
||||||
# use hub routing if available; otherwise fallback to softmax+topk
|
# For now, do not call routing to avoid extra overhead until
|
||||||
routed = None
|
# grouped GEMM integration is complete. Use the naive compute path
|
||||||
if available():
|
# for correctness and baseline performance.
|
||||||
try:
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
routed = route_topk(router_logits, top_k)
|
|
||||||
except Exception:
|
|
||||||
routed = None
|
|
||||||
if routed is None:
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
|
||||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
|
||||||
topk_weight = topk_weight.to(flat.dtype)
|
|
||||||
x_rep = flat.repeat_interleave(top_k, dim=0)
|
|
||||||
y = torch.empty_like(x_rep)
|
|
||||||
flat_idx = topk_idx.view(-1)
|
|
||||||
for i in range(experts_module.num_experts):
|
|
||||||
expert = experts_module[i]
|
|
||||||
y[flat_idx == i] = expert(x_rep[flat_idx == i])
|
|
||||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
||||||
return y.reshape(bsz, seqlen, hdim), router_logits
|
|
||||||
# If routed via hub, still fallback to per-expert compute until grouped GEMM path is wired.
|
|
||||||
ex_routing_data, gather_idx, scatter_idx = routed
|
|
||||||
# Convert to naive per-expert compute on packed tokens (future: call matmul_ogs + swiglu)
|
|
||||||
# For now, reconstruct the same result as naive path (no speedup but validates routing).
|
|
||||||
# We map the selected experts from gather_idx back to expert ids via router_logits argmax among top-k.
|
|
||||||
# Simpler: reuse naive computation for correctness; detailed integration will follow.
|
|
||||||
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
|
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype)
|
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
topk_weight = topk_weight.to(flat.dtype)
|
||||||
x_rep = flat.repeat_interleave(top_k, dim=0)
|
x_rep = flat.repeat_interleave(top_k, dim=0)
|
||||||
y = torch.empty_like(x_rep)
|
y = torch.empty_like(x_rep)
|
||||||
flat_idx = topk_idx.view(-1)
|
flat_idx = topk_idx.view(-1)
|
||||||
for i in range(experts_module.num_experts):
|
for i in range(experts_module.num_experts):
|
||||||
expert = experts_module[i]
|
expert = experts_module[i]
|
||||||
y[flat_idx == i] = expert(x_rep[flat_idx == i])
|
sel = flat_idx == i
|
||||||
|
if sel.any():
|
||||||
|
y[sel] = expert(x_rep[sel])
|
||||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
return y.reshape(bsz, seqlen, hdim), router_logits
|
return y.reshape(bsz, seqlen, hdim), router_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user