Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
936149380f support nemotron for scattermoe-lora 2026-03-23 21:29:58 +00:00
Wing Lian
86be9f329e post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI

* handle lora kernel auto-enable for moe without grouped_mm

* prefer not to import torch in schema validation
2026-03-23 02:26:10 -04:00
4 changed files with 126 additions and 24 deletions

View File

@@ -36,6 +36,8 @@ SPARSE_MOE_BLOCK = {
"glm4v_moe": "Glm4vMoeTextMoE",
# sigmoid -> topk routing (no group selection)
"minimax_m2": "MiniMaxM2SparseMoeBlock",
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
"nemotron_h": "NemotronHMoE",
# Models below need custom routing (not yet implemented):
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)

View File

@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
instead of ``gate_up_proj``.
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
gup = getattr(base_experts, "gate_up_proj", None)
if gup is not None:
num_experts = gup.shape[0]
for attr in ("gate_up_proj", "up_proj"):
param = getattr(base_experts, attr, None)
if param is not None:
num_experts = param.shape[0]
break
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
Supports:
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
"""
@staticmethod
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
# ====================================================================
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
# ====================================================================
is_gated = hasattr(experts, "gate_up_proj")
up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
# ====================================================================
# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
# ====================================================================
fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
expert_input = hidden_states_flat
if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
expert_input = fc1_latent_proj(hidden_states_flat)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
and up_proj_attr in experts.parametrizations
)
if use_selective:
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
num_experts,
)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
up_W = selective_expert_weights(
experts,
"gate_up_proj",
up_proj_attr,
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
).transpose(2, 1)
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
up_W = getattr(experts, up_proj_attr).transpose(2, 1)
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# Up projection (gated: gate_up_proj; non-gated: up_proj)
# ====================================================================
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
up_out = parallel_linear_lora(
expert_input,
up_W,
top_k,
sei_gup,
sorted_scattered_idxs,
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
use_fused_gather=True,
)
else:
gup = parallel_linear(
hidden_states_flat,
gate_up_W,
up_out = parallel_linear(
expert_input,
up_W,
top_k,
sei_gup,
sorted_scattered_idxs,
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
grouped_out=True,
)
gates, h = gup.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
# ====================================================================
# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
# ====================================================================
if is_gated:
gates, h = up_out.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
else:
h = experts.act_fn(up_out)
# ====================================================================
# Down projection
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
gates=routing_weights,
)
# ====================================================================
# Optional latent projection back to hidden_size (NemotronH)
# ====================================================================
if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
expert_output = fc2_latent_proj(expert_output)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================

View File

@@ -1385,6 +1385,39 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip auto-enable for MoE models when native grouped_mm is unavailable
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
# with out= which bypasses autocast and fails on mixed dtypes during eval.
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
if not has_grouped_mm:
is_moe = False
model_type = data.get("model_config_type", "")
if model_type and "moe" in model_type.lower():
is_moe = True
if not is_moe:
try:
from transformers import AutoConfig
base_model = data.get("base_model")
if base_model:
auto_cfg = AutoConfig.from_pretrained(
base_model, trust_remote_code=False
)
if getattr(auto_cfg, "num_local_experts", None) or getattr(
auto_cfg, "num_experts", None
):
is_moe = True
except Exception: # pylint: disable=broad-exception-caught
pass
if is_moe:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1

View File

@@ -176,24 +176,31 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
X.requires_grad = True
output = LoRA_MLP.apply(
X,
None, # X_drop
gate_proj.weight,
gate_proj.bias,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
None, # gate_lora_bias
None, # gate_magnitude
up_proj.weight,
up_proj.bias,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
None, # up_lora_bias
None, # up_magnitude
down_proj.weight,
down_proj.bias,
None, # down_quant
None, # down_A
None, # down_B
None, # down_scale
None, # down_lora_bias
None, # down_magnitude
activation_forward,
activation_backward,
True, # inplace
@@ -247,24 +254,31 @@ def test_lora_mlp_with_adapters(
# Forward pass with adapters
output = LoRA_MLP.apply(
X,
None, # X_drop
gate_proj.weight,
gate_proj.bias,
None,
gate_A,
gate_B,
scale,
None, # gate_lora_bias
None, # gate_magnitude
up_proj.weight,
up_proj.bias,
None,
up_A,
up_B,
scale,
None, # up_lora_bias
None, # up_magnitude
down_proj.weight,
down_proj.bias,
None,
down_A,
down_B,
scale,
None, # down_lora_bias
None, # down_magnitude
activation_forward,
activation_backward,
True,
@@ -334,25 +348,32 @@ def test_lora_qkv(sample_tensors):
Q1, K1, V1 = LoRA_QKV.apply(
X,
None, # X_drop
q_weight,
None,
None,
None,
None,
None,
None,
None, # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
k_weight,
None,
None,
None,
None,
None,
None,
None, # K
v_weight,
None,
None,
None,
None,
None,
True,
None,
None, # V
True, # inplace
)
assert Q1.shape == K1.shape == V1.shape == X.shape
@@ -366,25 +387,32 @@ def test_lora_qkv(sample_tensors):
# Test with LoRA adapters
Q2, K2, V2 = LoRA_QKV.apply(
X,
None, # X_drop
q_weight,
None,
None,
q_A,
q_B,
scale,
None,
None, # Q
k_weight,
None,
None,
k_A,
k_B,
scale,
None,
None, # K
v_weight,
None,
None,
v_A,
v_B,
scale,
True,
None,
None, # V
True, # inplace
)
assert Q2.shape == K2.shape == V2.shape == X.shape
@@ -427,7 +455,9 @@ def test_lora_o(sample_tensors):
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, b, None, A, B, scale)
output = LoRA_O.apply(
X, None, W, b, None, A, B, scale, None, None
) # X_drop, ..., lora_bias, magnitude
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -542,6 +572,7 @@ def test_inplace_operations(sample_tensors, apply_function):
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
device="cuda", dtype=torch.float16
),
"training": False,
},
)