fix
This commit is contained in:
@@ -10,20 +10,78 @@ import torch
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
|
||||
# Default value of the ``group_size_m`` argument of ``patch_deepseek_v3_moe``.
_GROUP_SIZE_M = 128
|
||||
# Names of the per-expert linear submodules whose weights get stacked into
# combined ``<name>_weight`` parameters on the MoE module.
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
return hidden_states.is_cuda and hidden_states.shape[0] > 0
|
||||
|
||||
|
||||
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
gate_weights = [expert.gate_proj.weight for expert in module.experts]
|
||||
up_weights = [expert.up_proj.weight for expert in module.experts]
|
||||
down_weights = [expert.down_proj.weight for expert in module.experts]
|
||||
gate = torch.stack(gate_weights, dim=0)
|
||||
up = torch.stack(up_weights, dim=0)
|
||||
down = torch.stack(down_weights, dim=0)
|
||||
return gate, up, down
|
||||
def _ensure_combined_expert_weights(
    module, dtype: torch.dtype, device: torch.device
) -> None:
    """Stack the per-expert projection weights into combined module parameters.

    For every name in ``_COMBINED_SUBMODULES`` the ``weight`` of each expert's
    submodule is converted to ``device``/``dtype``, stacked along a new leading
    dimension (``torch.stack(..., dim=0)``) and registered on ``module`` as
    ``f"{name}_weight"``. The per-expert ``weight`` entries are removed from
    the experts, and the first expert's original device/dtype is recorded in
    ``module._axolotl_original_specs[name]``.

    When ``module._axolotl_combined_weights`` is already true, the cached
    combined parameters are only replaced if their device or dtype differs
    from the requested one.

    Args:
        module: Module exposing an iterable ``experts`` whose items provide
            the submodules named in ``_COMBINED_SUBMODULES``.
        dtype: Working dtype for the combined weights.
        device: Working device for the combined weights.

    Raises:
        RuntimeError: If an expert submodule has no ``weight``, carries a
            non-``None`` ``bias``, or the weights of one projection do not
            share a shape. These checks run before any parameter is touched.
    """
    if not hasattr(module, "_axolotl_original_specs"):
        module._axolotl_original_specs = {}

    if getattr(module, "_axolotl_combined_weights", False):
        # Move cached combined weights to the working dtype/device if required.
        for name in _COMBINED_SUBMODULES:
            param_name = f"{name}_weight"
            param = module.get_parameter(param_name)
            if param.device != device or param.dtype != dtype:
                module._parameters[param_name] = torch.nn.Parameter(
                    param.detach().to(device=device, dtype=dtype).contiguous(),
                    # Keep frozen weights frozen instead of the default True.
                    requires_grad=param.requires_grad,
                )
        module._axolotl_combined_dtype = dtype
        module._axolotl_combined_device = device
        return

    linears = {
        name: [expert.get_submodule(name) for expert in module.experts]
        for name in _COMBINED_SUBMODULES
    }

    # Validate everything up front. Failing half-way through the destructive
    # loop below would leave some experts without weights while the module is
    # not yet flagged as combined, so nothing could restore them.
    for name, layers in linears.items():
        shapes = set()
        for lin in layers:
            weight_param = lin._parameters.get("weight")
            if weight_param is None:
                raise RuntimeError("Expected expert linear layers to have weights")
            if lin._parameters.get("bias") is not None:
                # Only weights are stacked; combining would silently drop it.
                raise RuntimeError(
                    "Expert linear layers with a bias cannot be combined"
                )
            shapes.add(weight_param.shape)
        if len(shapes) > 1:
            raise RuntimeError(
                f"Expert {name} weights must share one shape to be stacked"
            )

    for name, layers in linears.items():
        weights = []
        orig_device = None
        orig_dtype = None
        requires_grad = False
        for lin in layers:
            weight_param = lin._parameters["weight"]
            if orig_device is None:
                orig_device = weight_param.device
                orig_dtype = weight_param.dtype
            requires_grad = requires_grad or weight_param.requires_grad
            weights.append(weight_param.detach().to(device=device, dtype=dtype))
            # Drop the per-expert copy only after its conversion succeeded.
            # A ``bias`` slot (``None`` for bias-free linears) is left alone so
            # the expert's own ``forward`` still works once weights are restored.
            del lin._parameters["weight"]
        module.register_parameter(
            f"{name}_weight",
            torch.nn.Parameter(
                torch.stack(weights, dim=0).contiguous(),
                requires_grad=requires_grad,
            ),
        )
        module._axolotl_original_specs[name] = (orig_device, orig_dtype)

    module._axolotl_combined_weights = True
    module._axolotl_combined_dtype = dtype
    module._axolotl_combined_device = device
|
||||
|
||||
|
||||
def _restore_expert_weights(module) -> None:
    """Split the combined ``<name>_weight`` parameters back onto the experts.

    Does nothing unless ``module._axolotl_combined_weights`` is true. For each
    name in ``_COMBINED_SUBMODULES`` the combined parameter is removed from
    ``module`` and slice ``idx`` of it is copied into a fresh ``weight``
    parameter on expert ``idx``. The copy is placed on the device/dtype
    recorded in ``module._axolotl_original_specs`` (falling back to the
    combined tensor's own device/dtype). Finally the combined-state
    attributes on ``module`` are reset.

    Args:
        module: Module exposing ``experts`` and, when combined, the
            ``<name>_weight`` parameters.
    """
    if not getattr(module, "_axolotl_combined_weights", False):
        return

    # Tolerate a missing spec dict; the per-name lookup below has a default.
    original_specs = getattr(module, "_axolotl_original_specs", {})

    for name in _COMBINED_SUBMODULES:
        combined = module._parameters.pop(f"{name}_weight")
        orig_device, orig_dtype = original_specs.get(
            name, (combined.device, combined.dtype)
        )
        for idx, expert in enumerate(module.experts):
            lin = expert.get_submodule(name)
            lin._parameters["weight"] = torch.nn.Parameter(
                combined[idx].detach().clone().to(orig_device, dtype=orig_dtype),
                # Carry the trainable/frozen state over instead of forcing True.
                requires_grad=combined.requires_grad,
            )
            # Make sure a ``bias`` attribute exists after restoring: if the
            # slot was dropped while the weights were combined, linear layers
            # that read ``self.bias`` in ``forward`` would fail.
            if not hasattr(lin, "bias"):
                lin._parameters["bias"] = None

    module._axolotl_combined_weights = False
    module._axolotl_combined_dtype = None
    module._axolotl_combined_device = None
|
||||
|
||||
|
||||
def _moe_triton_forward(
|
||||
@@ -89,7 +147,11 @@ def _moe_triton_forward(
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
gate_weights, up_weights, down_weights = _collect_expert_weights(module)
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device)
|
||||
|
||||
gate_weights = module.get_parameter("gate_proj_weight")
|
||||
up_weights = module.get_parameter("up_proj_weight")
|
||||
down_weights = module.get_parameter("down_proj_weight")
|
||||
|
||||
gate_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
@@ -141,7 +203,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
||||
original_moe = DeepseekV3MoE.moe
|
||||
|
||||
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
||||
with contextlib.suppress(RuntimeError):
|
||||
try:
|
||||
return _moe_triton_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -150,7 +212,9 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
||||
group_size_m,
|
||||
original_moe,
|
||||
)
|
||||
return original_moe(self, hidden_states, topk_indices, topk_weights)
|
||||
except RuntimeError:
|
||||
_restore_expert_weights(self)
|
||||
return original_moe(self, hidden_states, topk_indices, topk_weights)
|
||||
|
||||
DeepseekV3MoE.moe = patched_moe
|
||||
DeepseekV3MoE._axolotl_triton_patch = True
|
||||
|
||||
Reference in New Issue
Block a user