Fix
This commit is contained in:
@@ -142,7 +142,11 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
|||||||
raise SystemExit("CUDA requested but not available")
|
raise SystemExit("CUDA requested but not available")
|
||||||
|
|
||||||
baseline_module = build_module(args)
|
baseline_module = build_module(args)
|
||||||
original_moe = DeepseekV3MoE.moe
|
original_moe = getattr(
|
||||||
|
DeepseekV3MoE,
|
||||||
|
"_axolotl_triton_original_moe",
|
||||||
|
DeepseekV3MoE.moe,
|
||||||
|
)
|
||||||
baseline_module.moe = MethodType(original_moe, baseline_module)
|
baseline_module.moe = MethodType(original_moe, baseline_module)
|
||||||
state_dict = baseline_module.state_dict()
|
state_dict = baseline_module.state_dict()
|
||||||
|
|
||||||
|
|||||||
@@ -241,10 +241,15 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
|||||||
|
|
||||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||||
|
|
||||||
|
# Record the unpatched implementation so callers can access a true baseline even
|
||||||
|
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
|
||||||
|
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
|
||||||
|
DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
|
||||||
|
|
||||||
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
|
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
|
||||||
return
|
return
|
||||||
|
|
||||||
original_moe = DeepseekV3MoE.moe
|
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
|
||||||
|
|
||||||
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user