Compare commits
Comparing textui ... scattermoe (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 936149380f |  |
|  | 86be9f329e |  |
@@ -36,6 +36,8 @@ SPARSE_MOE_BLOCK = {
     "glm4v_moe": "Glm4vMoeTextMoE",
     # sigmoid -> topk routing (no group selection)
     "minimax_m2": "MiniMaxM2SparseMoeBlock",
+    # sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
+    "nemotron_h": "NemotronHMoE",
     # Models below need custom routing (not yet implemented):
     # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
     # "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
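Note: `SPARSE_MOE_BLOCK` maps a Hugging Face `model_type` string to the class name of that architecture's sparse-MoE block. A minimal sketch of how such a registry can be consumed when patching a model (the helper below is illustrative, not the repo's actual code):

```python
# Sketch only: assumes SPARSE_MOE_BLOCK maps model_type -> MoE block class name.
SPARSE_MOE_BLOCK = {
    "glm4v_moe": "Glm4vMoeTextMoE",
    "minimax_m2": "MiniMaxM2SparseMoeBlock",
    "nemotron_h": "NemotronHMoE",
}

def find_moe_blocks(model):
    """Yield (name, module) pairs whose class matches the registry entry
    for this model's config.model_type. Hypothetical helper."""
    target = SPARSE_MOE_BLOCK.get(model.config.model_type)
    if target is None:
        return
    for name, module in model.named_modules():
        if type(module).__name__ == target:
            yield name, module
```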
@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
       -> base_layer: ParamWrapper(gate_up_proj)
         -> base_layer: OlmoeExperts (the real module)
 
+    For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
+    instead of ``gate_up_proj``.
+
     This function walks the chain, collects LoRA params keyed by
     ``parameter_name``, and returns the base experts module.
 
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
 
     Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
     A/B are already in scattermoe layout.
+    For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
     """
     # Collect ParamWrapper layers by their parameter_name
     wrappers = {}
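Note: peft's LoRA-on-parameters support nests wrapper modules around the real experts, each exposing a `base_layer`. A minimal sketch of the unwrapping idea, assuming only the `base_layer` / `parameter_name` attributes named in the docstring:

```python
# Sketch of walking a peft wrapper chain. Assumes every wrapper exposes
# `base_layer`, and ParamWrapper additionally exposes `parameter_name`.
def unwrap_experts(module):
    wrappers = {}  # parameter_name -> ParamWrapper
    current = module
    while hasattr(current, "base_layer"):
        if hasattr(current, "parameter_name"):
            wrappers[current.parameter_name] = current
        current = current.base_layer
    return current, wrappers  # (real experts module, collected wrappers)
```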
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
     num_experts = getattr(base_experts, "num_experts", None)
     if num_experts is None:
         # Fallback: infer from parameter shape
-        gup = getattr(base_experts, "gate_up_proj", None)
-        if gup is not None:
-            num_experts = gup.shape[0]
+        for attr in ("gate_up_proj", "up_proj"):
+            param = getattr(base_experts, attr, None)
+            if param is not None:
+                num_experts = param.shape[0]
+                break
 
-    # Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
+    # Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
     gup_lora = None
-    gup_wrapper = wrappers.get("gate_up_proj")
+    gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
     if gup_wrapper is not None:
         lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
         if lora_A is not None:
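Note on the "A<->B swap": a standard linear computes `x @ W.T`, and LoRA adds `scaling * (B @ A)` to `W`. A layout that stores the transposed weight therefore needs the transposed delta, and since `(B @ A).T == A.T @ B.T`, the two factors trade places. A toy check, independent of the repo's code:

```python
import torch

# Shapes follow the nn.Linear convention: W is [out, in],
# LoRA A is [r, in], LoRA B is [out, r].
torch.manual_seed(0)
r, d_in, d_out = 4, 16, 32
A = torch.randn(r, d_in)
B = torch.randn(d_out, r)

delta = B @ A        # [out, in], added to W
delta_t = A.T @ B.T  # [in, out], added to W.T (transposed expert layout)
assert torch.allclose(delta.T, delta_t)
```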
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
     Supports:
 
     * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
-    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
     * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
     * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
       extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+    * **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
+    * **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
     """
 
     @staticmethod
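Note: for readers comparing the two routing families listed in the docstring, a minimal sketch in plain PyTorch (not the repo's router code; whether and how a model renormalizes the selected sigmoid scores varies by architecture):

```python
import torch

def softmax_topk(logits, top_k):
    # Normalize over all experts first, then select (OLMoE/Mixtral-style).
    probs = torch.softmax(logits, dim=-1)
    return torch.topk(probs, top_k, dim=-1)

def sigmoid_topk(logits, top_k, eps=1e-9):
    # Score each expert independently, select, then renormalize the
    # selected scores (DeepSeek-V3-style; details differ per model).
    scores = torch.sigmoid(logits)
    weights, idx = torch.topk(scores, top_k, dim=-1)
    return weights / (weights.sum(dim=-1, keepdim=True) + eps), idx
```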
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
         hidden_states_flat = layer_input.view(-1, hidden_dim)
 
         # ====================================================================
-        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
+        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
         # ====================================================================
         shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
 
@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
         # ====================================================================
         experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
 
+        # ====================================================================
+        # Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
+        # ====================================================================
+        is_gated = hasattr(experts, "gate_up_proj")
+        up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
+
+        # ====================================================================
+        # Optional latent projection (NemotronH: fc1/fc2_latent_proj)
+        # ====================================================================
+        fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
+        fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
+
+        expert_input = hidden_states_flat
+        if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
+            expert_input = fc1_latent_proj(hidden_states_flat)
+
         # ====================================================================
         # Selective expert weight dequantization
         # ====================================================================
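Note: the latent-projection pattern sandwiches the experts between two dense projections so they run at a smaller latent width. A minimal sketch of that wiring (module and dimension names are assumptions, not NemotronH's exact definitions):

```python
import torch.nn as nn

class LatentMoE(nn.Module):
    """Sketch: hidden -> latent, experts in latent space, latent -> hidden.
    Mirrors the fc1_latent_proj / fc2_latent_proj pattern above."""

    def __init__(self, hidden_size, latent_size, experts):
        super().__init__()
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size, bias=False)
        self.experts = experts  # any module mapping latent -> latent
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size, bias=False)

    def forward(self, x):
        return self.fc2_latent_proj(self.experts(self.fc1_latent_proj(x)))
```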
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
         use_selective = (
             getattr(self, "_use_selective_dequant", False)
             and hasattr(experts, "parametrizations")
-            and "gate_up_proj" in experts.parametrizations
+            and up_proj_attr in experts.parametrizations
         )
 
         if use_selective:
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
                 num_experts,
             )
             # Dequantize only active experts' weights
-            gate_up_W = selective_expert_weights(
+            up_W = selective_expert_weights(
                 experts,
-                "gate_up_proj",
+                up_proj_attr,
                 active_experts,
-            ).transpose(2, 1)  # [num_active, hidden, 2*inter]
+            ).transpose(2, 1)
 
             # Remap LoRA weights to match compact expert indices
             if gup_lora is not None:
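Note on the `.transpose(2, 1)`: per the comment being removed, the dequantized weights come out as `[num_active, 2*inter, hidden]` and are transposed to `[num_active, hidden, 2*inter]` so token vectors can right-multiply. A toy shape check:

```python
import torch

E, hidden, inter, n_tokens = 4, 8, 16, 5
W = torch.randn(E, 2 * inter, hidden)  # per-expert storage layout
W_t = W.transpose(2, 1)                # [E, hidden, 2*inter]
tokens = torch.randn(n_tokens, hidden)
out = tokens @ W_t[0]                  # tokens routed to expert 0
assert out.shape == (n_tokens, 2 * inter)
```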
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
                 sei_gup = remapped_expert_idxs
                 eo_gup = compact_offsets
         else:
-            gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]
+            up_W = getattr(experts, up_proj_attr).transpose(2, 1)
             sei_gup = sorted_expert_idxs
             eo_gup = expert_offsets
 
         # ====================================================================
-        # Gate + Up projection
+        # Up projection (gated: gate_up_proj; non-gated: up_proj)
         # ====================================================================
         if gup_lora is not None:
             gup_A, gup_B, gup_scaling = gup_lora
-            gup = parallel_linear_lora(
-                hidden_states_flat,
-                gate_up_W,
+            up_out = parallel_linear_lora(
+                expert_input,
+                up_W,
                 top_k,
                 sei_gup,
                 sorted_scattered_idxs,
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
                 use_fused_gather=True,
             )
         else:
-            gup = parallel_linear(
-                hidden_states_flat,
-                gate_up_W,
+            up_out = parallel_linear(
+                expert_input,
+                up_W,
                 top_k,
                 sei_gup,
                 sorted_scattered_idxs,
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
                 grouped_out=True,
             )
 
-        gates, h = gup.chunk(2, dim=-1)
-        h = experts.act_fn(gates) * h
+        # ====================================================================
+        # Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
+        # ====================================================================
+        if is_gated:
+            gates, h = up_out.chunk(2, dim=-1)
+            h = experts.act_fn(gates) * h
+        else:
+            h = experts.act_fn(up_out)
 
         # ====================================================================
         # Down projection
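Note: the new branch distinguishes the gated (SwiGLU-style) expert MLP from the plain two-matrix variant. A minimal sketch of both paths, independent of the fused ScatterMoE kernels (the activation functions here are placeholders; each model supplies its own `act_fn`):

```python
import torch
import torch.nn.functional as F

def gated_expert(x, gate_up_W, down_W):
    # One fused matmul, then split into gate and up halves.
    gates, h = (x @ gate_up_W).chunk(2, dim=-1)
    return (F.silu(gates) * h) @ down_W

def nongated_expert(x, up_W, down_W):
    # Plain MLP expert: up-project, activate, down-project.
    return F.relu(x @ up_W) @ down_W
```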
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
             gates=routing_weights,
         )
 
+        # ====================================================================
+        # Optional latent projection back to hidden_size (NemotronH)
+        # ====================================================================
+        if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
+            expert_output = fc2_latent_proj(expert_output)
+
         # ====================================================================
         # Combine with shared expert and reshape
         # ====================================================================
@@ -1385,6 +1385,39 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data.get("trust_remote_code"):
             return data
 
+        # Skip auto-enable for MoE models when native grouped_mm is unavailable
+        # (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
+        # with out= which bypasses autocast and fails on mixed dtypes during eval.
+        env_capabilities = data.get("env_capabilities", {})
+        torch_version = env_capabilities.get("torch_version")
+        if torch_version is None:
+            import torch
+
+            torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+        has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
+        if not has_grouped_mm:
+            is_moe = False
+            model_type = data.get("model_config_type", "")
+            if model_type and "moe" in model_type.lower():
+                is_moe = True
+            if not is_moe:
+                try:
+                    from transformers import AutoConfig
+
+                    base_model = data.get("base_model")
+                    if base_model:
+                        auto_cfg = AutoConfig.from_pretrained(
+                            base_model, trust_remote_code=False
+                        )
+                        if getattr(auto_cfg, "num_local_experts", None) or getattr(
+                            auto_cfg, "num_experts", None
+                        ):
+                            is_moe = True
+                except Exception:  # pylint: disable=broad-exception-caught
+                    pass
+            if is_moe:
+                return data
+
         # Check multi-GPU compatibility
         capabilities = data.get("capabilities")
         is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
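Note: a standalone sketch of the torch-version gate added above, assuming `packaging` is available; the local build suffix (e.g. `2.8.0+cu121`) is stripped before parsing, exactly as in the hunk:

```python
import torch
from packaging import version

def has_native_grouped_mm() -> bool:
    """True when this torch build is new enough (>= 2.9.0) for the
    native grouped_mm path the comment above refers to."""
    base = str(torch.__version__).split("+", maxsplit=1)[0]
    return version.parse(base) >= version.parse("2.9.0")
```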
@@ -176,24 +176,31 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
     X.requires_grad = True
     output = LoRA_MLP.apply(
         X,
+        None,  # X_drop
         gate_proj.weight,
         gate_proj.bias,
         None,  # gate_quant
         None,  # gate_A
         None,  # gate_B
         None,  # gate_scale
+        None,  # gate_lora_bias
+        None,  # gate_magnitude
         up_proj.weight,
         up_proj.bias,
         None,  # up_quant
         None,  # up_A
         None,  # up_B
         None,  # up_scale
+        None,  # up_lora_bias
+        None,  # up_magnitude
         down_proj.weight,
         down_proj.bias,
         None,  # down_quant
         None,  # down_A
         None,  # down_B
         None,  # down_scale
+        None,  # down_lora_bias
+        None,  # down_magnitude
         activation_forward,
         activation_backward,
         True,  # inplace
@@ -247,24 +254,31 @@ def test_lora_mlp_with_adapters(
     # Forward pass with adapters
     output = LoRA_MLP.apply(
         X,
+        None,  # X_drop
         gate_proj.weight,
         gate_proj.bias,
         None,
         gate_A,
         gate_B,
         scale,
+        None,  # gate_lora_bias
+        None,  # gate_magnitude
         up_proj.weight,
         up_proj.bias,
         None,
         up_A,
         up_B,
         scale,
+        None,  # up_lora_bias
+        None,  # up_magnitude
         down_proj.weight,
         down_proj.bias,
         None,
         down_A,
         down_B,
         scale,
+        None,  # down_lora_bias
+        None,  # down_magnitude
         activation_forward,
         activation_backward,
         True,
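Note: these `apply` calls take a long run of positional arguments, which gets easy to misalign as the signature grows. One way to keep such tests readable is a small helper that expands a per-projection argument group (a hypothetical convenience, not part of this PR):

```python
# Hypothetical helper: build the 8-argument group each projection expects:
# (weight, bias, quant, A, B, scale, lora_bias, magnitude).
def proj_args(weight, bias=None, quant=None, A=None, B=None,
              scale=None, lora_bias=None, magnitude=None):
    return (weight, bias, quant, A, B, scale, lora_bias, magnitude)

# Usage sketch:
# output = LoRA_MLP.apply(
#     X, None,  # X, X_drop
#     *proj_args(gate_proj.weight, gate_proj.bias, A=gate_A, B=gate_B, scale=scale),
#     *proj_args(up_proj.weight, up_proj.bias, A=up_A, B=up_B, scale=scale),
#     *proj_args(down_proj.weight, down_proj.bias, A=down_A, B=down_B, scale=scale),
#     activation_forward, activation_backward, True,
# )
```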
@@ -334,25 +348,32 @@ def test_lora_qkv(sample_tensors):
 
     Q1, K1, V1 = LoRA_QKV.apply(
         X,
+        None,  # X_drop
         q_weight,
         None,
         None,
         None,
         None,
         None,
+        None,
+        None,  # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
         k_weight,
         None,
         None,
         None,
         None,
         None,
+        None,
+        None,  # K
         v_weight,
         None,
         None,
         None,
         None,
         None,
-        True,
+        None,
+        None,  # V
+        True,  # inplace
     )
 
     assert Q1.shape == K1.shape == V1.shape == X.shape
@@ -366,25 +387,32 @@ def test_lora_qkv(sample_tensors):
     # Test with LoRA adapters
    Q2, K2, V2 = LoRA_QKV.apply(
         X,
+        None,  # X_drop
         q_weight,
         None,
         None,
         q_A,
         q_B,
         scale,
+        None,
+        None,  # Q
         k_weight,
         None,
         None,
         k_A,
         k_B,
         scale,
+        None,
+        None,  # K
         v_weight,
         None,
         None,
         v_A,
         v_B,
         scale,
-        True,
+        None,
+        None,  # V
+        True,  # inplace
     )
 
     assert Q2.shape == K2.shape == V2.shape == X.shape
@@ -427,7 +455,9 @@ def test_lora_o(sample_tensors):
 
     # Test forward pass
     X.requires_grad = True
-    output = LoRA_O.apply(X, W, b, None, A, B, scale)
+    output = LoRA_O.apply(
+        X, None, W, b, None, A, B, scale, None, None
+    )  # X_drop, ..., lora_bias, magnitude
 
     assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
 
@@ -542,6 +572,7 @@ def test_inplace_operations(sample_tensors, apply_function):
         "down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
             device="cuda", dtype=torch.float16
         ),
+        "training": False,
     },
 )
 