add mg kernel backend

This commit is contained in:
Dan Saunders
2025-09-23 15:43:16 -04:00
parent 8a1f5ae940
commit d0da67eb17
9 changed files with 1753 additions and 77 deletions

View File

@@ -32,7 +32,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3
@@ -98,6 +98,12 @@ def parse_args() -> argparse.Namespace:
default=128,
help="GROUP_SIZE_M used by the Triton kernel",
)
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
return parser.parse_args()
@@ -163,7 +169,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size)
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
patched_module = build_module(args)
patched_module.load_state_dict(state_dict)
@@ -250,6 +256,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
return {
"device": device,
"backend": args.backend,
"dtype": dtype,
"baseline_ms": baseline_ms,
"patched_ms": patched_ms,
@@ -270,7 +277,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
result = benchmark_deepseek_v3(args)
print(
f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"