Fix
This commit is contained in:
@@ -142,7 +142,11 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
raise SystemExit("CUDA requested but not available")
|
||||
|
||||
baseline_module = build_module(args)
|
||||
original_moe = DeepseekV3MoE.moe
|
||||
original_moe = getattr(
|
||||
DeepseekV3MoE,
|
||||
"_axolotl_triton_original_moe",
|
||||
DeepseekV3MoE.moe,
|
||||
)
|
||||
baseline_module.moe = MethodType(original_moe, baseline_module)
|
||||
state_dict = baseline_module.state_dict()
|
||||
|
||||
|
||||
@@ -241,10 +241,15 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
||||
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
|
||||
# Record the unpatched implementation so callers can access a true baseline even
|
||||
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
|
||||
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
|
||||
DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
|
||||
|
||||
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
|
||||
return
|
||||
|
||||
original_moe = DeepseekV3MoE.moe
|
||||
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
|
||||
|
||||
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user