shared expert detection
This commit is contained in:
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -28,9 +27,15 @@ def _parse_int_list(value: str) -> List[int]:
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
|
||||
p.add_argument("--batch-sizes", default="4,8,16", help="Comma separated batch sizes")
|
||||
p.add_argument("--seq-lens", default="1024,2048", help="Comma separated sequence lengths")
|
||||
p.add_argument("--experts", default="8,16,32,64", help="Comma separated expert counts")
|
||||
p.add_argument(
|
||||
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
|
||||
)
|
||||
p.add_argument(
|
||||
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
|
||||
)
|
||||
p.add_argument(
|
||||
"--experts", default="8,16,32,64", help="Comma separated expert counts"
|
||||
)
|
||||
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
@@ -188,9 +193,7 @@ def _run_case(
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (
|
||||
(diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
|
||||
)
|
||||
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
|
||||
|
||||
tokens = bsz * seq
|
||||
flops = _estimate_flops(tokens, hidden, inter, top_k)
|
||||
@@ -215,10 +218,10 @@ def _run_case(
|
||||
)
|
||||
|
||||
|
||||
def _print_header(hidden: int, inter: int, dtype: torch.dtype, device: torch.device) -> None:
|
||||
print(
|
||||
f"Device={device} dtype={dtype} hidden={hidden} inter={inter}"
|
||||
)
|
||||
def _print_header(
|
||||
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
|
||||
) -> None:
|
||||
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
|
||||
print(
|
||||
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
|
||||
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
|
||||
|
||||
Reference in New Issue
Block a user