diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py index 92d57b64c..2e9d53a32 100644 --- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py +++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py @@ -10,20 +10,78 @@ import torch from axolotl.kernels.moe import ContiguousGroupedGEMM _GROUP_SIZE_M = 128 +_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj") def _is_triton_eligible(hidden_states: torch.Tensor) -> bool: return hidden_states.is_cuda and hidden_states.shape[0] > 0 -def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - gate_weights = [expert.gate_proj.weight for expert in module.experts] - up_weights = [expert.up_proj.weight for expert in module.experts] - down_weights = [expert.down_proj.weight for expert in module.experts] - gate = torch.stack(gate_weights, dim=0) - up = torch.stack(up_weights, dim=0) - down = torch.stack(down_weights, dim=0) - return gate, up, down +def _ensure_combined_expert_weights( + module, dtype: torch.dtype, device: torch.device +) -> None: + if not hasattr(module, "_axolotl_original_specs"): + module._axolotl_original_specs = {} + if getattr(module, "_axolotl_combined_weights", False): + # Move cached combined weights to the working dtype/device if required. + for name in _COMBINED_SUBMODULES: + param_name = f"{name}_weight" + param = module.get_parameter(param_name) + if param.device != device or param.dtype != dtype: + module._parameters[param_name] = torch.nn.Parameter( + param.to(device=device, dtype=dtype).contiguous() + ) + module._axolotl_combined_dtype = dtype + module._axolotl_combined_device = device + return + + combined = {} + for name in _COMBINED_SUBMODULES: + weights = [] + orig_device = None + orig_dtype = None + for expert in module.experts: + lin = expert.get_submodule(name) + weight_param = lin._parameters.get("weight") + if weight_param is None: + raise RuntimeError("Expected expert linear layers to have weights") + if orig_device is None: + orig_device = weight_param.device + orig_dtype = weight_param.dtype + weights.append(weight_param.detach().to(device=device, dtype=dtype)) + if "weight" in lin._parameters: + del lin._parameters["weight"] + if "bias" in lin._parameters: + # DeepseekV3 MLP layers are bias-free, but keep this for safety. + del lin._parameters["bias"] + combined[name] = torch.stack(weights, dim=0).contiguous() + module.register_parameter( + f"{name}_weight", torch.nn.Parameter(combined[name]) + ) + module._axolotl_original_specs[name] = (orig_device, orig_dtype) + + module._axolotl_combined_weights = True + module._axolotl_combined_dtype = dtype + module._axolotl_combined_device = device + + +def _restore_expert_weights(module) -> None: + if not getattr(module, "_axolotl_combined_weights", False): + return + + for name in _COMBINED_SUBMODULES: + param_name = f"{name}_weight" + combined = module._parameters.pop(param_name) + orig_device, orig_dtype = module._axolotl_original_specs.get(name, (combined.device, combined.dtype)) + for idx, expert in enumerate(module.experts): + lin = expert.get_submodule(name) + lin._parameters["weight"] = torch.nn.Parameter( + combined[idx].detach().clone().to(orig_device, dtype=orig_dtype) + ) + + module._axolotl_combined_weights = False + module._axolotl_combined_dtype = None + module._axolotl_combined_device = None def _moe_triton_forward( @@ -89,7 +147,11 @@ def _moe_triton_forward( .contiguous() ) - gate_weights, up_weights, down_weights = _collect_expert_weights(module) + _ensure_combined_expert_weights(module, hidden_dtype, device) + + gate_weights = module.get_parameter("gate_proj_weight") + up_weights = module.get_parameter("up_proj_weight") + down_weights = module.get_parameter("down_proj_weight") gate_out = ContiguousGroupedGEMM.apply( grouped_hidden, @@ -141,7 +203,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: original_moe = DeepseekV3MoE.moe def patched_moe(self, hidden_states, topk_indices, topk_weights): - with contextlib.suppress(RuntimeError): + try: return _moe_triton_forward( self, hidden_states, @@ -150,7 +212,9 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: group_size_m, original_moe, ) - return original_moe(self, hidden_states, topk_indices, topk_weights) + except RuntimeError: + _restore_expert_weights(self) + return original_moe(self, hidden_states, topk_indices, topk_weights) DeepseekV3MoE.moe = patched_moe DeepseekV3MoE._axolotl_triton_patch = True