add mg kernel backend

This commit is contained in:
Dan Saunders
2025-09-23 15:43:16 -04:00
parent 8a1f5ae940
commit d0da67eb17
9 changed files with 1753 additions and 77 deletions

View File

@@ -32,7 +32,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3
@@ -98,6 +98,12 @@ def parse_args() -> argparse.Namespace:
default=128,
help="GROUP_SIZE_M used by the Triton kernel",
)
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
return parser.parse_args()
@@ -163,7 +169,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size)
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
patched_module = build_module(args)
patched_module.load_state_dict(state_dict)
@@ -250,6 +256,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
return {
"device": device,
"backend": args.backend,
"dtype": dtype,
"baseline_ms": baseline_ms,
"patched_ms": patched_ms,
@@ -270,7 +277,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
result = benchmark_deepseek_v3(args)
print(
f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"

View File

@@ -21,7 +21,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import (
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
@@ -42,6 +42,12 @@ def parse_args() -> argparse.Namespace:
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
@@ -105,6 +111,7 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
{
"dtype": args.dtype,
"device": args.device,
"backend": args.backend,
"warmup": args.warmup,
"iters": args.iters,
"seed": args.seed,
@@ -164,6 +171,7 @@ def main() -> None: # pragma: no cover - utility script
"n_experts",
"top_k",
"groups",
"backend",
"baseline_ms",
"patched_ms",
"speedup",
@@ -177,10 +185,10 @@ def main() -> None: # pragma: no cover - utility script
rows = []
print(
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}"
)
print(
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
)
@@ -206,6 +214,7 @@ def main() -> None: # pragma: no cover - utility script
cfg["n_experts"],
cfg["top_k"],
cfg["groups"],
args.backend,
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
@@ -219,7 +228,7 @@ def main() -> None: # pragma: no cover - utility script
)
status = "OK" if result["accuracy_ok"] else "FAIL"
print(
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
)