Compare commits


1 Commit

Author      SHA1         Message                                 Date
Wing Lian   936149380f   support nemotron for scattermoe-lora    2026-03-23 21:29:58 +00:00
2 changed files with 59 additions and 21 deletions

View File

@@ -36,6 +36,8 @@ SPARSE_MOE_BLOCK = {
"glm4v_moe": "Glm4vMoeTextMoE",
# sigmoid -> topk routing (no group selection)
"minimax_m2": "MiniMaxM2SparseMoeBlock",
+# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
+"nemotron_h": "NemotronHMoE",
# Models below need custom routing (not yet implemented):
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)

View File

@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
+For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
+instead of ``gate_up_proj``.
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
+For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
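Editor's note: to make the wrapper-chain description in the docstring above concrete, here is a minimal sketch of walking a peft-style ``base_layer`` chain and collecting wrappers keyed by ``parameter_name``. The attribute names follow the docstring; the function name `collect_param_wrappers` is hypothetical and this is not the implementation changed in this diff.

# Minimal sketch (assumption): walk the base_layer chain described above,
# collecting each ParamWrapper by its parameter_name and returning the
# underlying experts module.
def collect_param_wrappers(experts_module):
    wrappers = {}
    current = experts_module
    while hasattr(current, "base_layer"):
        name = getattr(current, "parameter_name", None)
        if name is not None:
            wrappers[name] = current  # e.g. "gate_up_proj" / "up_proj" / "down_proj"
        current = current.base_layer
    return wrappers, current  # `current` is now the real experts module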
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
-gup = getattr(base_experts, "gate_up_proj", None)
-if gup is not None:
-num_experts = gup.shape[0]
+for attr in ("gate_up_proj", "up_proj"):
+param = getattr(base_experts, attr, None)
+if param is not None:
+num_experts = param.shape[0]
+break
-# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
+# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
-gup_wrapper = wrappers.get("gate_up_proj")
+gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
Supports:
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
-* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
+* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
"""
@staticmethod
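Editor's note: the class docstring above distinguishes gated and non-gated expert layouts. As a dense, single-expert reference, the sketch below shows the per-expert math for both layouts; the function names and the use of SiLU are illustrative assumptions, and the real code applies ``experts.act_fn`` inside fused scattermoe kernels rather than dense matmuls.

# Minimal dense reference (assumption): per-expert MLP math for the two
# layouts listed in the docstring above. Single expert, no routing.
import torch.nn.functional as F

def gated_expert(x, gate_up_W, down_W):
    # gate_up_proj produces [gates | up]; output = act(gates) * up -> down_proj
    gates, up = (x @ gate_up_W).chunk(2, dim=-1)
    return (F.silu(gates) * up) @ down_W  # SiLU illustrative; real code uses experts.act_fn

def non_gated_expert(x, up_W, down_W):
    # up_proj only (NemotronH-style); output = act(up) -> down_proj
    return F.silu(x @ up_W) @ down_W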
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
-# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
+# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
# ====================================================================
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
+# ====================================================================
+# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
+# ====================================================================
+is_gated = hasattr(experts, "gate_up_proj")
+up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
+# ====================================================================
+# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
+# ====================================================================
+fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
+fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
+expert_input = hidden_states_flat
+if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
+expert_input = fc1_latent_proj(hidden_states_flat)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
and up_proj_attr in experts.parametrizations
)
if use_selective:
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
num_experts,
)
# Dequantize only active experts' weights
-gate_up_W = selective_expert_weights(
+up_W = selective_expert_weights(
experts,
"gate_up_proj",
up_proj_attr,
active_experts,
-).transpose(2, 1) # [num_active, hidden, 2*inter]
+).transpose(2, 1)
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
-gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
+up_W = getattr(experts, up_proj_attr).transpose(2, 1)
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
-# Gate + Up projection
+# Up projection (gated: gate_up_proj; non-gated: up_proj)
# ====================================================================
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
-gup = parallel_linear_lora(
-hidden_states_flat,
-gate_up_W,
+up_out = parallel_linear_lora(
+expert_input,
+up_W,
top_k,
sei_gup,
sorted_scattered_idxs,
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
use_fused_gather=True,
)
else:
-gup = parallel_linear(
-hidden_states_flat,
-gate_up_W,
+up_out = parallel_linear(
+expert_input,
+up_W,
top_k,
sei_gup,
sorted_scattered_idxs,
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
grouped_out=True,
)
-gates, h = gup.chunk(2, dim=-1)
-h = experts.act_fn(gates) * h
+# ====================================================================
+# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
+# ====================================================================
+if is_gated:
+gates, h = up_out.chunk(2, dim=-1)
+h = experts.act_fn(gates) * h
+else:
+h = experts.act_fn(up_out)
# ====================================================================
# Down projection
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
gates=routing_weights,
)
+# ====================================================================
+# Optional latent projection back to hidden_size (NemotronH)
+# ====================================================================
+if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
+expert_output = fc2_latent_proj(expert_output)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================
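Editor's note: pulling the NemotronH-specific pieces of this diff together, the per-token expert path now looks roughly like the dense, single-expert sketch below. Routing, LoRA, and the fused scattermoe kernels are omitted; the latent projections are applied only when present and not ``nn.Identity``, as in the forward above, while the function name, projection shapes, and activation choice here are illustrative assumptions.

# Dense, single-expert sketch (assumption) of the NemotronH path added in
# this diff: optional fc1 latent projection -> non-gated expert
# (up_proj, act, down_proj) -> optional fc2 latent projection back to
# hidden_size.
import torch.nn as nn
import torch.nn.functional as F

def nemotron_like_expert_path(x, fc1_latent_proj, up_W, down_W, fc2_latent_proj, act_fn=F.silu):
    # act_fn is illustrative; the real code uses experts.act_fn
    expert_input = x
    if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
        expert_input = fc1_latent_proj(x)                # project hidden_size -> latent dim
    h = act_fn(expert_input @ up_W)                      # non-gated: act(up_proj(x))
    expert_output = h @ down_W                           # down_proj
    if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
        expert_output = fc2_latent_proj(expert_output)   # project back to hidden_size
    return expert_output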