shared expert detection

This commit is contained in:
Dan Saunders
2025-09-19 11:24:14 -04:00
parent bfc848f81d
commit 3bfed0aac8
2 changed files with 30 additions and 12 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import argparse
import csv
import math
import sys
import time
from dataclasses import dataclass
@@ -28,9 +27,15 @@ def _parse_int_list(value: str) -> List[int]:
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument("--batch-sizes", default="4,8,16", help="Comma separated batch sizes")
p.add_argument("--seq-lens", default="1024,2048", help="Comma separated sequence lengths")
p.add_argument("--experts", default="8,16,32,64", help="Comma separated expert counts")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
@@ -188,9 +193,7 @@ def _run_case(
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
)
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
@@ -215,10 +218,10 @@ def _run_case(
)
def _print_header(hidden: int, inter: int, dtype: torch.dtype, device: torch.device) -> None:
print(
f"Device={device} dtype={dtype} hidden={hidden} inter={inter}"
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"

View File

@@ -277,6 +277,17 @@ def moe_ffn_forward_grouped(
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
shared_out_flat: Optional[torch.Tensor] = None
if hasattr(experts_module, "shared_expert"):
shared_expert = experts_module.shared_expert
shared_out_flat = shared_expert(x_flat)
shared_out_flat = shared_out_flat.to(expert_dtype)
shared_gate = getattr(experts_module, "shared_expert_gate", None)
if shared_gate is not None:
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
gate_vals = torch.sigmoid(gate_input)
shared_out_flat.mul_(gate_vals.to(expert_dtype))
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
@@ -321,4 +332,8 @@ def moe_ffn_forward_grouped(
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, down_out)
return combined.view(bsz, seqlen, hdim), router_logits
output = combined.view(bsz, seqlen, hdim)
if shared_out_flat is not None:
output = output + shared_out_flat.view(bsz, seqlen, hdim)
return output, router_logits