diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 1045810a6..453d8002d 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -142,7 +142,11 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict: raise SystemExit("CUDA requested but not available") baseline_module = build_module(args) - original_moe = DeepseekV3MoE.moe + original_moe = getattr( + DeepseekV3MoE, + "_axolotl_triton_original_moe", + DeepseekV3MoE.moe, + ) baseline_module.moe = MethodType(original_moe, baseline_module) state_dict = baseline_module.state_dict() diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py index e8a7408d8..d613a49f6 100644 --- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py +++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py @@ -241,10 +241,15 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE + # Record the unpatched implementation so callers can access a true baseline even + # after the Triton patch has been applied (e.g. repeated microbenchmarks). + if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"): + DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe + if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False): return - original_moe = DeepseekV3MoE.moe + original_moe = DeepseekV3MoE._axolotl_triton_original_moe def patched_moe(self, hidden_states, topk_indices, topk_weights): try: