shared expert detection

This commit is contained in:
Dan Saunders
2025-09-19 11:24:14 -04:00
parent bfc848f81d
commit 3bfed0aac8
2 changed files with 30 additions and 12 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import argparse
import csv
import math
import sys
import time
from dataclasses import dataclass
@@ -28,9 +27,15 @@ def _parse_int_list(value: str) -> List[int]:
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument("--batch-sizes", default="4,8,16", help="Comma separated batch sizes")
p.add_argument("--seq-lens", default="1024,2048", help="Comma separated sequence lengths")
p.add_argument("--experts", default="8,16,32,64", help="Comma separated expert counts")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
@@ -188,9 +193,7 @@ def _run_case(
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
)
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
@@ -215,10 +218,10 @@ def _run_case(
)
def _print_header(hidden: int, inter: int, dtype: torch.dtype, device: torch.device) -> None:
print(
f"Device={device} dtype={dtype} hidden={hidden} inter={inter}"
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"