This commit is contained in:
Dan Saunders
2025-09-22 22:48:11 -04:00
parent 94cbc6d42d
commit 5b97633faa
2 changed files with 11 additions and 2 deletions

View File

@@ -142,7 +142,11 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
raise SystemExit("CUDA requested but not available")
baseline_module = build_module(args)
original_moe = DeepseekV3MoE.moe
original_moe = getattr(
DeepseekV3MoE,
"_axolotl_triton_original_moe",
DeepseekV3MoE.moe,
)
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()

View File

@@ -241,10 +241,15 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
# Record the unpatched implementation so callers can access a true baseline even
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
return
original_moe = DeepseekV3MoE.moe
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
def patched_moe(self, hidden_states, topk_indices, topk_weights):
try: