This commit is contained in:
Dan Saunders
2025-09-21 16:37:10 -04:00
parent 6a45d804f9
commit 18269ee6a9

View File

@@ -10,20 +10,78 @@ import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
# Default value of the ``group_size_m`` argument of ``patch_deepseek_v3_moe``.
_GROUP_SIZE_M = 128
# Names of the per-expert submodules whose ``weight`` parameters are stacked
# into a single ``<name>_weight`` parameter on the MoE module (and split back
# out again when the combined layout is undone).
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
    """Report whether ``hidden_states`` can take the Triton path.

    The tensor qualifies only when it lives on a CUDA device and its leading
    dimension is non-empty.
    """
    if not hidden_states.is_cuda:
        return False
    return hidden_states.shape[0] > 0
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Stack the projection weights of every expert in ``module.experts``.

    For each of ``gate_proj``, ``up_proj`` and ``down_proj`` the ``weight`` of
    every expert is read, and the per-expert tensors are stacked along a new
    leading dimension (one slice per expert, in iteration order).

    Returns:
        The stacked ``(gate, up, down)`` weight tensors.
    """
    # Gather all three weight lists before stacking anything, so attribute
    # lookups on every expert happen ahead of the first ``torch.stack`` call.
    per_projection = [
        [getattr(expert, proj_name).weight for expert in module.experts]
        for proj_name in ("gate_proj", "up_proj", "down_proj")
    ]
    stacked_gate, stacked_up, stacked_down = (
        torch.stack(tensors, dim=0) for tensors in per_projection
    )
    return stacked_gate, stacked_up, stacked_down
def _ensure_combined_expert_weights(
    module, dtype: torch.dtype, device: torch.device
) -> None:
    """Replace per-expert projection weights with stacked module parameters.

    For every name in ``_COMBINED_SUBMODULES`` the ``weight`` of each expert in
    ``module.experts`` is converted to ``dtype``/``device``, stacked along a
    new leading dimension and registered on ``module`` as ``<name>_weight``.
    The per-expert ``weight`` entries are then removed from the expert
    submodules so the data is only held once.

    If the weights were already combined, the cached parameters are only moved
    to ``dtype``/``device`` when they differ.

    Bookkeeping written on ``module``:
        ``_axolotl_original_specs``: ``{name: (device, dtype)}`` of the first
            expert's original weight, recorded per projection.
        ``_axolotl_combined_weights``: set to ``True`` once combining is done.
        ``_axolotl_combined_dtype`` / ``_axolotl_combined_device``: the
            requested working dtype and device.

    Args:
        module: MoE module exposing an iterable ``experts`` attribute.
        dtype: Working dtype for the combined weights.
        device: Working device for the combined weights.

    Raises:
        RuntimeError: If any expert submodule has no ``weight`` parameter.
            This is detected before any expert is modified, so the module is
            left untouched in that case.
    """
    if not hasattr(module, "_axolotl_original_specs"):
        module._axolotl_original_specs = {}

    if getattr(module, "_axolotl_combined_weights", False):
        # Move cached combined weights to the working dtype/device if required.
        for name in _COMBINED_SUBMODULES:
            param_name = f"{name}_weight"
            param = module.get_parameter(param_name)
            if param.device != device or param.dtype != dtype:
                module._parameters[param_name] = torch.nn.Parameter(
                    param.to(device=device, dtype=dtype).contiguous(),
                    # Keep frozen weights frozen across the move.
                    requires_grad=param.requires_grad,
                )
        module._axolotl_combined_dtype = dtype
        module._axolotl_combined_device = device
        return

    # Phase 1: resolve and validate every expert submodule before mutating
    # anything, so a validation failure cannot leave some experts stripped of
    # their weights while the combined flag is still unset.
    layers_by_name = {}
    for name in _COMBINED_SUBMODULES:
        layers = [expert.get_submodule(name) for expert in module.experts]
        if any(lin._parameters.get("weight") is None for lin in layers):
            raise RuntimeError("Expected expert linear layers to have weights")
        layers_by_name[name] = layers

    # Phase 2: per projection, build the stacked parameter first and only then
    # drop the per-expert weights.
    for name, layers in layers_by_name.items():
        originals = [lin._parameters["weight"] for lin in layers]
        stacked = torch.stack(
            [w.detach().to(device=device, dtype=dtype) for w in originals],
            dim=0,
        ).contiguous()
        combined_param = torch.nn.Parameter(
            stacked,
            # Only train the combined weight if some expert weight was trainable.
            requires_grad=any(w.requires_grad for w in originals),
        )
        module.register_parameter(f"{name}_weight", combined_param)
        # Device/dtype of the first expert are recorded for a later restore.
        module._axolotl_original_specs[name] = (
            originals[0].device,
            originals[0].dtype,
        )
        for lin in layers:
            # NOTE: the "bias" entry is intentionally left in place.
            # ``nn.Linear(bias=False)`` registers ``bias`` as ``None``; removing
            # the entry makes ``lin.bias`` raise ``AttributeError`` once the
            # weight is restored and the layer is called again.
            del lin._parameters["weight"]

    module._axolotl_combined_weights = True
    module._axolotl_combined_dtype = dtype
    module._axolotl_combined_device = device
def _restore_expert_weights(module) -> None:
    """Undo the combined-weight layout and give each expert its weight back.

    Does nothing unless ``module._axolotl_combined_weights`` is truthy.
    Otherwise, for every name in ``_COMBINED_SUBMODULES``, slice ``idx`` of the
    ``<name>_weight`` parameter is cloned, cast to the device/dtype recorded in
    ``module._axolotl_original_specs`` (falling back to the combined
    parameter's own device/dtype) and installed as the ``weight`` of expert
    ``idx``. The combined parameter is then removed from ``module`` and the
    combined-state bookkeeping attributes are reset.

    Args:
        module: MoE module exposing an iterable ``experts`` attribute.
    """
    if not getattr(module, "_axolotl_combined_weights", False):
        return

    original_specs = getattr(module, "_axolotl_original_specs", {})
    for name in _COMBINED_SUBMODULES:
        param_name = f"{name}_weight"
        # Keep the combined parameter registered until every expert copy has
        # succeeded, so a failure mid-way does not discard the only copy.
        combined = module._parameters[param_name]
        orig_device, orig_dtype = original_specs.get(
            name, (combined.device, combined.dtype)
        )
        for idx, expert in enumerate(module.experts):
            lin = expert.get_submodule(name)
            lin._parameters["weight"] = torch.nn.Parameter(
                combined[idx].detach().clone().to(orig_device, dtype=orig_dtype),
                requires_grad=combined.requires_grad,
            )
            # Combining may have removed the "bias" entry entirely. Without
            # it ``lin.bias`` raises ``AttributeError`` and ``nn.Linear``
            # cannot run, so put back the ``None`` placeholder that a
            # bias-free linear normally carries. An existing entry is kept.
            lin._parameters.setdefault("bias", None)
        del module._parameters[param_name]

    module._axolotl_combined_weights = False
    module._axolotl_combined_dtype = None
    module._axolotl_combined_device = None
def _moe_triton_forward(
@@ -89,7 +147,11 @@ def _moe_triton_forward(
.contiguous()
)
gate_weights, up_weights, down_weights = _collect_expert_weights(module)
_ensure_combined_expert_weights(module, hidden_dtype, device)
gate_weights = module.get_parameter("gate_proj_weight")
up_weights = module.get_parameter("up_proj_weight")
down_weights = module.get_parameter("down_proj_weight")
gate_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
@@ -141,7 +203,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
original_moe = DeepseekV3MoE.moe
def patched_moe(self, hidden_states, topk_indices, topk_weights):
with contextlib.suppress(RuntimeError):
try:
return _moe_triton_forward(
self,
hidden_states,
@@ -150,7 +212,9 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
group_size_m,
original_moe,
)
return original_moe(self, hidden_states, topk_indices, topk_weights)
except RuntimeError:
_restore_expert_weights(self)
return original_moe(self, hidden_states, topk_indices, topk_weights)
DeepseekV3MoE.moe = patched_moe
DeepseekV3MoE._axolotl_triton_patch = True