add mg kernel backend
This commit is contained in:
@@ -32,7 +32,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
|
||||
|
||||
ACCURACY_TOLERANCE = 5e-3
|
||||
|
||||
@@ -98,6 +98,12 @@ def parse_args() -> argparse.Namespace:
|
||||
default=128,
|
||||
help="GROUP_SIZE_M used by the Triton kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["cg", "mg"],
|
||||
default="mg",
|
||||
help="MoE kernel backend to benchmark",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -163,7 +169,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
baseline_module.moe = MethodType(original_moe, baseline_module)
|
||||
state_dict = baseline_module.state_dict()
|
||||
|
||||
patch_deepseek_v3_moe(group_size_m=args.group_size)
|
||||
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
|
||||
patched_module = build_module(args)
|
||||
patched_module.load_state_dict(state_dict)
|
||||
|
||||
@@ -250,6 +256,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
|
||||
return {
|
||||
"device": device,
|
||||
"backend": args.backend,
|
||||
"dtype": dtype,
|
||||
"baseline_ms": baseline_ms,
|
||||
"patched_ms": patched_ms,
|
||||
@@ -270,7 +277,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
result = benchmark_deepseek_v3(args)
|
||||
|
||||
print(
|
||||
f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
||||
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
||||
)
|
||||
print(
|
||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||
|
||||
Reference in New Issue
Block a user