diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py
index 8002b3f79..10a3c0644 100644
--- a/src/axolotl/integrations/kernels/constants.py
+++ b/src/axolotl/integrations/kernels/constants.py
@@ -36,6 +36,8 @@ SPARSE_MOE_BLOCK = {
     "glm4v_moe": "Glm4vMoeTextMoE",
     # sigmoid -> topk routing (no group selection)
     "minimax_m2": "MiniMaxM2SparseMoeBlock",
+    # sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
+    "nemotron_h": "NemotronHMoE",
     # Models below need custom routing (not yet implemented):
     # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",  # softmax->topk, e_score_correction_bias between softmax and topk
     # "deepseek_v2": "DeepseekV2Moe",  # softmax->topk, group_limited_greedy, different attr names (num_group)
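For context, the SPARSE_MOE_BLOCK registry above maps a Hugging Face model_type string to the name of that model's sparse-MoE block class, which the kernel integration presumably consults when deciding which module to swap. A minimal sketch of such a lookup, assuming only the dict shown above; the helper moe_block_class_for is illustrative, not an existing API:

# Illustrative only: `moe_block_class_for` is not part of the plugin; it just
# shows how the SPARSE_MOE_BLOCK registry can be consumed by a patcher.
from typing import Optional

from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK


def moe_block_class_for(model_type: str) -> Optional[str]:
    """Return the registered sparse-MoE block class name, or None when the
    model's routing is not yet supported (the commented-out entries)."""
    return SPARSE_MOE_BLOCK.get(model_type)


assert moe_block_class_for("nemotron_h") == "NemotronHMoE"
assert moe_block_class_for("deepseek_v2") is None  # custom routing not implemented yet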
diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
index c6c01e255..ca76e5c4a 100644
--- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
+++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
         -> base_layer: ParamWrapper(gate_up_proj)
           -> base_layer: OlmoeExperts (the real module)
 
+    For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
+    instead of ``gate_up_proj``.
+
     This function walks the chain, collects LoRA params keyed by
     ``parameter_name``, and returns the base experts module.
 
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
 
         Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
         A/B are already in scattermoe layout.
+        For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
     """
     # Collect ParamWrapper layers by their parameter_name
     wrappers = {}
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
     num_experts = getattr(base_experts, "num_experts", None)
     if num_experts is None:
         # Fallback: infer from parameter shape
-        gup = getattr(base_experts, "gate_up_proj", None)
-        if gup is not None:
-            num_experts = gup.shape[0]
+        for attr in ("gate_up_proj", "up_proj"):
+            param = getattr(base_experts, attr, None)
+            if param is not None:
+                num_experts = param.shape[0]
+                break
 
-    # Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
+    # Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
     gup_lora = None
-    gup_wrapper = wrappers.get("gate_up_proj")
+    gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
     if gup_wrapper is not None:
         lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
         if lora_A is not None:
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
 
     Supports:
     * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
-    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
    * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
     * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
       extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+    * **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
+    * **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
     """
 
     @staticmethod
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
         hidden_states_flat = layer_input.view(-1, hidden_dim)
 
         # ====================================================================
-        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
+        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
         # ====================================================================
         shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
 
@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
         # ====================================================================
         experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
 
+        # ====================================================================
+        # Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
+        # ====================================================================
+        is_gated = hasattr(experts, "gate_up_proj")
+        up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
+
+        # ====================================================================
+        # Optional latent projection (NemotronH: fc1/fc2_latent_proj)
+        # ====================================================================
+        fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
+        fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
+
+        expert_input = hidden_states_flat
+        if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
+            expert_input = fc1_latent_proj(hidden_states_flat)
+
         # ====================================================================
         # Selective expert weight dequantization
         # ====================================================================
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
         use_selective = (
             getattr(self, "_use_selective_dequant", False)
             and hasattr(experts, "parametrizations")
-            and "gate_up_proj" in experts.parametrizations
+            and up_proj_attr in experts.parametrizations
         )
 
         if use_selective:
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
                 num_experts,
             )
             # Dequantize only active experts' weights
-            gate_up_W = selective_expert_weights(
+            up_W = selective_expert_weights(
                 experts,
-                "gate_up_proj",
+                up_proj_attr,
                 active_experts,
-            ).transpose(2, 1)  # [num_active, hidden, 2*inter]
+            ).transpose(2, 1)
 
             # Remap LoRA weights to match compact expert indices
             if gup_lora is not None:
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
             sei_gup = remapped_expert_idxs
             eo_gup = compact_offsets
         else:
-            gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]
+            up_W = getattr(experts, up_proj_attr).transpose(2, 1)
             sei_gup = sorted_expert_idxs
             eo_gup = expert_offsets
 
         # ====================================================================
-        # Gate + Up projection
+        # Up projection (gated: gate_up_proj; non-gated: up_proj)
         # ====================================================================
         if gup_lora is not None:
             gup_A, gup_B, gup_scaling = gup_lora
-            gup = parallel_linear_lora(
-                hidden_states_flat,
-                gate_up_W,
+            up_out = parallel_linear_lora(
+                expert_input,
+                up_W,
                 top_k,
                 sei_gup,
                 sorted_scattered_idxs,
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
                 use_fused_gather=True,
             )
         else:
-            gup = parallel_linear(
-                hidden_states_flat,
-                gate_up_W,
+            up_out = parallel_linear(
+                expert_input,
+                up_W,
                 top_k,
                 sei_gup,
                 sorted_scattered_idxs,
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
                 grouped_out=True,
             )
 
-        gates, h = gup.chunk(2, dim=-1)
-        h = experts.act_fn(gates) * h
+        # ====================================================================
+        # Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
+        # ====================================================================
+        if is_gated:
+            gates, h = up_out.chunk(2, dim=-1)
+            h = experts.act_fn(gates) * h
+        else:
+            h = experts.act_fn(up_out)
 
         # ====================================================================
         # Down projection
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
             gates=routing_weights,
         )
 
+        # ====================================================================
+        # Optional latent projection back to hidden_size (NemotronH)
+        # ====================================================================
+        if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
+            expert_output = fc2_latent_proj(expert_output)
+
         # ====================================================================
         # Combine with shared expert and reshape
         # ====================================================================
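To make the gated vs. non-gated split concrete, the following is a reference-style sketch (plain PyTorch, single expert, with routing and the top-k weighted combine omitted) of the per-token math the fused path above reproduces. The function name, the argument layout, and applying fc2_latent directly after the down projection are illustrative simplifications, not the kernel implementation:

# Reference sketch only: single-expert, unfused; not the ScatterMoE kernels.
import torch


def expert_mlp_reference(x, up_w, down_w, act_fn, gated, fc1_latent=None, fc2_latent=None):
    # x: [tokens, hidden]; up_w / down_w are one expert's weights already in
    # [in, out] layout, matching the .transpose(2, 1) applied above.
    if fc1_latent is not None:          # NemotronH-style projection into the latent size
        x = fc1_latent(x)
    h = x @ up_w                        # gated: [tokens, 2*inter]; non-gated: [tokens, inter]
    if gated:
        gate, up = h.chunk(2, dim=-1)   # gate_up_proj packs the gate and up halves
        h = act_fn(gate) * up
    else:                               # NemotronH: no gate half, just act(up_proj(x))
        h = act_fn(h)
    out = h @ down_w
    if fc2_latent is not None:          # NemotronH-style projection back to hidden_size
        out = fc2_latent(out)           # (applied after the routed combine in the code above)
    return out


x = torch.randn(4, 64)
w_up, w_down = torch.randn(64, 2 * 128), torch.randn(128, 64)
y = expert_mlp_reference(x, w_up, w_down, torch.nn.functional.silu, gated=True)
assert y.shape == (4, 64)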