shared expert detection
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -28,9 +27,15 @@ def _parse_int_list(value: str) -> List[int]:
 
 def _parse_args() -> argparse.Namespace:
     p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
-    p.add_argument("--batch-sizes", default="4,8,16", help="Comma separated batch sizes")
-    p.add_argument("--seq-lens", default="1024,2048", help="Comma separated sequence lengths")
-    p.add_argument("--experts", default="8,16,32,64", help="Comma separated expert counts")
+    p.add_argument(
+        "--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
+    )
+    p.add_argument(
+        "--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
+    )
+    p.add_argument(
+        "--experts", default="8,16,32,64", help="Comma separated expert counts"
+    )
     p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
     p.add_argument("--hidden", type=int, default=4096)
     p.add_argument("--inter", type=int, default=14336)
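Note: the hunk context references `_parse_int_list`, whose body is not part of this diff. A minimal sketch of what such a helper could look like, consistent with the comma-separated defaults above (an assumed implementation, not code from the commit):

from typing import List

def _parse_int_list(value: str) -> List[int]:
    # Turn a CLI value such as "4,8,16" into [4, 8, 16], skipping empty fragments.
    return [int(part) for part in value.split(",") if part.strip()]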
@@ -188,9 +193,7 @@ def _run_case(
     diff = (y_naive.float() - y_grouped.float()).abs()
     max_abs = diff.max().item()
     mean_abs = diff.mean().item()
-    rel_l2 = (
-        (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
-    )
+    rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
 
     tokens = bsz * seq
     flops = _estimate_flops(tokens, hidden, inter, top_k)
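Note: the collapsed `rel_l2` expression is the relative L2 error between the naive and grouped outputs, i.e. ||y_naive - y_grouped||_2 / ||y_naive||_2 with a 1e-12 guard in the denominator. A small self-contained check of that equivalence on toy tensors (not the benchmark's data):

import torch

y_ref = torch.randn(8, 16)
y_test = y_ref + 1e-3 * torch.randn(8, 16)
diff = (y_ref - y_test).abs()
# Same formula as the script's rel_l2 line.
rel_l2 = (diff.pow(2).sum() / (y_ref.pow(2).sum() + 1e-12)).sqrt().item()
# Equivalent norm-ratio form; the 1e-12 guard only matters for an all-zero reference.
alt = (torch.linalg.vector_norm(y_ref - y_test) / torch.linalg.vector_norm(y_ref)).item()
assert abs(rel_l2 - alt) < 1e-6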
@@ -215,10 +218,10 @@ def _run_case(
     )
 
 
-def _print_header(hidden: int, inter: int, dtype: torch.dtype, device: torch.device) -> None:
-    print(
-        f"Device={device} dtype={dtype} hidden={hidden} inter={inter}"
-    )
+def _print_header(
+    hidden: int, inter: int, dtype: torch.dtype, device: torch.device
+) -> None:
+    print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
     print(
         "bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
         "naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
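Note: a quick usage sketch of the reformatted header printer; the argument values are just the script's defaults plus an assumed bf16 CUDA device, and the comment shows the line format produced by the f-string above:

import torch

# Prints: Device=cuda:0 dtype=torch.bfloat16 hidden=4096 inter=14336
# followed by the tab-separated column header.
_print_header(4096, 14336, torch.bfloat16, torch.device("cuda:0"))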
@@ -277,6 +277,17 @@ def moe_ffn_forward_grouped(
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
     router_logits = gate_linear(x_flat.to(routing_dtype))
 
+    shared_out_flat: Optional[torch.Tensor] = None
+    if hasattr(experts_module, "shared_expert"):
+        shared_expert = experts_module.shared_expert
+        shared_out_flat = shared_expert(x_flat)
+        shared_out_flat = shared_out_flat.to(expert_dtype)
+        shared_gate = getattr(experts_module, "shared_expert_gate", None)
+        if shared_gate is not None:
+            gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
+            gate_vals = torch.sigmoid(gate_input)
+            shared_out_flat.mul_(gate_vals.to(expert_dtype))
+
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
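Note: the detection above is duck-typed: any `experts_module` that exposes a `shared_expert` submodule is run on every token, and an optional `shared_expert_gate` projection scales that output through a sigmoid. A minimal sketch of a module shape this probe would pick up (the class and layer sizes are hypothetical, loosely modelled on shared-expert MoE layers; only the two attribute names come from the diff):

import torch
import torch.nn as nn

class TinySharedExpertModule(nn.Module):
    # Hypothetical expert container: `shared_expert` is applied to all tokens,
    # `shared_expert_gate` yields a per-token sigmoid scale.
    def __init__(self, hidden: int = 64, inter: int = 128) -> None:
        super().__init__()
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden, inter, bias=False),
            nn.SiLU(),
            nn.Linear(inter, hidden, bias=False),
        )
        self.shared_expert_gate = nn.Linear(hidden, 1, bias=False)

experts_module = TinySharedExpertModule()
x_flat = torch.randn(10, 64)

# Mirrors the hasattr/getattr logic added in this hunk.
shared_out_flat = None
if hasattr(experts_module, "shared_expert"):
    shared_out_flat = experts_module.shared_expert(x_flat)
    gate = getattr(experts_module, "shared_expert_gate", None)
    if gate is not None:
        shared_out_flat = shared_out_flat * torch.sigmoid(gate(x_flat))
print(shared_out_flat.shape)  # torch.Size([10, 64])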
@@ -321,4 +332,8 @@ def moe_ffn_forward_grouped(
 
     combined = torch.zeros_like(x_flat)
     combined.scatter_add_(0, gather_index, down_out)
-    return combined.view(bsz, seqlen, hdim), router_logits
+
+    output = combined.view(bsz, seqlen, hdim)
+    if shared_out_flat is not None:
+        output = output + shared_out_flat.view(bsz, seqlen, hdim)
+    return output, router_logits
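Note: the new tail keeps the routed combine (scatter_add_ of per-expert outputs back to their token rows) and then adds the always-on shared-expert output as a residual term. A small standalone illustration of that combine pattern; names and shapes here are illustrative, not the script's:

import torch

tokens, hdim, top_k = 6, 4, 2
down_out = torch.randn(tokens * top_k, hdim)              # one row per (token, expert slot)
token_ids = torch.arange(tokens).repeat_interleave(top_k)
gather_index = token_ids.unsqueeze(-1).expand(-1, hdim).contiguous()

combined = torch.zeros(tokens, hdim)
combined.scatter_add_(0, gather_index, down_out)          # sum the top_k slots per token

shared_out_flat = torch.randn(tokens, hdim)               # stand-in for the shared-expert output
output = combined + shared_out_flat                       # residual add, as in the new return path
print(output.shape)  # torch.Size([6, 4])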