Compare commits
2 Commits
textui
...
scattermoe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
936149380f | ||
|
|
86be9f329e |
@@ -36,6 +36,8 @@ SPARSE_MOE_BLOCK = {
|
|||||||
"glm4v_moe": "Glm4vMoeTextMoE",
|
"glm4v_moe": "Glm4vMoeTextMoE",
|
||||||
# sigmoid -> topk routing (no group selection)
|
# sigmoid -> topk routing (no group selection)
|
||||||
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
||||||
|
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
|
||||||
|
"nemotron_h": "NemotronHMoE",
|
||||||
# Models below need custom routing (not yet implemented):
|
# Models below need custom routing (not yet implemented):
|
||||||
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
||||||
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||||
|
|||||||
@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
-> base_layer: ParamWrapper(gate_up_proj)
|
-> base_layer: ParamWrapper(gate_up_proj)
|
||||||
-> base_layer: OlmoeExperts (the real module)
|
-> base_layer: OlmoeExperts (the real module)
|
||||||
|
|
||||||
|
For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
|
||||||
|
instead of ``gate_up_proj``.
|
||||||
|
|
||||||
This function walks the chain, collects LoRA params keyed by
|
This function walks the chain, collects LoRA params keyed by
|
||||||
``parameter_name``, and returns the base experts module.
|
``parameter_name``, and returns the base experts module.
|
||||||
|
|
||||||
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
|
|
||||||
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
|
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
|
||||||
A/B are already in scattermoe layout.
|
A/B are already in scattermoe layout.
|
||||||
|
For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
|
||||||
"""
|
"""
|
||||||
# Collect ParamWrapper layers by their parameter_name
|
# Collect ParamWrapper layers by their parameter_name
|
||||||
wrappers = {}
|
wrappers = {}
|
||||||
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
num_experts = getattr(base_experts, "num_experts", None)
|
num_experts = getattr(base_experts, "num_experts", None)
|
||||||
if num_experts is None:
|
if num_experts is None:
|
||||||
# Fallback: infer from parameter shape
|
# Fallback: infer from parameter shape
|
||||||
gup = getattr(base_experts, "gate_up_proj", None)
|
for attr in ("gate_up_proj", "up_proj"):
|
||||||
if gup is not None:
|
param = getattr(base_experts, attr, None)
|
||||||
num_experts = gup.shape[0]
|
if param is not None:
|
||||||
|
num_experts = param.shape[0]
|
||||||
|
break
|
||||||
|
|
||||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
|
||||||
gup_lora = None
|
gup_lora = None
|
||||||
gup_wrapper = wrappers.get("gate_up_proj")
|
gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
|
||||||
if gup_wrapper is not None:
|
if gup_wrapper is not None:
|
||||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
|
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
|
||||||
if lora_A is not None:
|
if lora_A is not None:
|
||||||
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
Supports:
|
Supports:
|
||||||
|
|
||||||
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
|
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
|
||||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
|
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
|
||||||
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
|
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
|
||||||
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
|
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
|
||||||
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
|
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
|
||||||
|
* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
|
||||||
|
* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
hidden_states_flat = layer_input.view(-1, hidden_dim)
|
hidden_states_flat = layer_input.view(-1, hidden_dim)
|
||||||
|
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
|
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||||
|
|
||||||
@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
# ====================================================================
|
# ====================================================================
|
||||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||||
|
|
||||||
|
# ====================================================================
|
||||||
|
# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
|
||||||
|
# ====================================================================
|
||||||
|
is_gated = hasattr(experts, "gate_up_proj")
|
||||||
|
up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
|
||||||
|
|
||||||
|
# ====================================================================
|
||||||
|
# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
|
||||||
|
# ====================================================================
|
||||||
|
fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
|
||||||
|
fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
|
||||||
|
|
||||||
|
expert_input = hidden_states_flat
|
||||||
|
if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
|
||||||
|
expert_input = fc1_latent_proj(hidden_states_flat)
|
||||||
|
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Selective expert weight dequantization
|
# Selective expert weight dequantization
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
use_selective = (
|
use_selective = (
|
||||||
getattr(self, "_use_selective_dequant", False)
|
getattr(self, "_use_selective_dequant", False)
|
||||||
and hasattr(experts, "parametrizations")
|
and hasattr(experts, "parametrizations")
|
||||||
and "gate_up_proj" in experts.parametrizations
|
and up_proj_attr in experts.parametrizations
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_selective:
|
if use_selective:
|
||||||
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
num_experts,
|
num_experts,
|
||||||
)
|
)
|
||||||
# Dequantize only active experts' weights
|
# Dequantize only active experts' weights
|
||||||
gate_up_W = selective_expert_weights(
|
up_W = selective_expert_weights(
|
||||||
experts,
|
experts,
|
||||||
"gate_up_proj",
|
up_proj_attr,
|
||||||
active_experts,
|
active_experts,
|
||||||
).transpose(2, 1) # [num_active, hidden, 2*inter]
|
).transpose(2, 1)
|
||||||
|
|
||||||
# Remap LoRA weights to match compact expert indices
|
# Remap LoRA weights to match compact expert indices
|
||||||
if gup_lora is not None:
|
if gup_lora is not None:
|
||||||
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
sei_gup = remapped_expert_idxs
|
sei_gup = remapped_expert_idxs
|
||||||
eo_gup = compact_offsets
|
eo_gup = compact_offsets
|
||||||
else:
|
else:
|
||||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
up_W = getattr(experts, up_proj_attr).transpose(2, 1)
|
||||||
sei_gup = sorted_expert_idxs
|
sei_gup = sorted_expert_idxs
|
||||||
eo_gup = expert_offsets
|
eo_gup = expert_offsets
|
||||||
|
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Gate + Up projection
|
# Up projection (gated: gate_up_proj; non-gated: up_proj)
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
if gup_lora is not None:
|
if gup_lora is not None:
|
||||||
gup_A, gup_B, gup_scaling = gup_lora
|
gup_A, gup_B, gup_scaling = gup_lora
|
||||||
gup = parallel_linear_lora(
|
up_out = parallel_linear_lora(
|
||||||
hidden_states_flat,
|
expert_input,
|
||||||
gate_up_W,
|
up_W,
|
||||||
top_k,
|
top_k,
|
||||||
sei_gup,
|
sei_gup,
|
||||||
sorted_scattered_idxs,
|
sorted_scattered_idxs,
|
||||||
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
use_fused_gather=True,
|
use_fused_gather=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
gup = parallel_linear(
|
up_out = parallel_linear(
|
||||||
hidden_states_flat,
|
expert_input,
|
||||||
gate_up_W,
|
up_W,
|
||||||
top_k,
|
top_k,
|
||||||
sei_gup,
|
sei_gup,
|
||||||
sorted_scattered_idxs,
|
sorted_scattered_idxs,
|
||||||
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
grouped_out=True,
|
grouped_out=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
gates, h = gup.chunk(2, dim=-1)
|
# ====================================================================
|
||||||
h = experts.act_fn(gates) * h
|
# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
|
||||||
|
# ====================================================================
|
||||||
|
if is_gated:
|
||||||
|
gates, h = up_out.chunk(2, dim=-1)
|
||||||
|
h = experts.act_fn(gates) * h
|
||||||
|
else:
|
||||||
|
h = experts.act_fn(up_out)
|
||||||
|
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Down projection
|
# Down projection
|
||||||
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
gates=routing_weights,
|
gates=routing_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ====================================================================
|
||||||
|
# Optional latent projection back to hidden_size (NemotronH)
|
||||||
|
# ====================================================================
|
||||||
|
if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
|
||||||
|
expert_output = fc2_latent_proj(expert_output)
|
||||||
|
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Combine with shared expert and reshape
|
# Combine with shared expert and reshape
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
|
|||||||
@@ -1385,6 +1385,39 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("trust_remote_code"):
|
if data.get("trust_remote_code"):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# Skip auto-enable for MoE models when native grouped_mm is unavailable
|
||||||
|
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
|
||||||
|
# with out= which bypasses autocast and fails on mixed dtypes during eval.
|
||||||
|
env_capabilities = data.get("env_capabilities", {})
|
||||||
|
torch_version = env_capabilities.get("torch_version")
|
||||||
|
if torch_version is None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
|
||||||
|
if not has_grouped_mm:
|
||||||
|
is_moe = False
|
||||||
|
model_type = data.get("model_config_type", "")
|
||||||
|
if model_type and "moe" in model_type.lower():
|
||||||
|
is_moe = True
|
||||||
|
if not is_moe:
|
||||||
|
try:
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
base_model = data.get("base_model")
|
||||||
|
if base_model:
|
||||||
|
auto_cfg = AutoConfig.from_pretrained(
|
||||||
|
base_model, trust_remote_code=False
|
||||||
|
)
|
||||||
|
if getattr(auto_cfg, "num_local_experts", None) or getattr(
|
||||||
|
auto_cfg, "num_experts", None
|
||||||
|
):
|
||||||
|
is_moe = True
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
pass
|
||||||
|
if is_moe:
|
||||||
|
return data
|
||||||
|
|
||||||
# Check multi-GPU compatibility
|
# Check multi-GPU compatibility
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||||
|
|||||||
@@ -176,24 +176,31 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
|||||||
X.requires_grad = True
|
X.requires_grad = True
|
||||||
output = LoRA_MLP.apply(
|
output = LoRA_MLP.apply(
|
||||||
X,
|
X,
|
||||||
|
None, # X_drop
|
||||||
gate_proj.weight,
|
gate_proj.weight,
|
||||||
gate_proj.bias,
|
gate_proj.bias,
|
||||||
None, # gate_quant
|
None, # gate_quant
|
||||||
None, # gate_A
|
None, # gate_A
|
||||||
None, # gate_B
|
None, # gate_B
|
||||||
None, # gate_scale
|
None, # gate_scale
|
||||||
|
None, # gate_lora_bias
|
||||||
|
None, # gate_magnitude
|
||||||
up_proj.weight,
|
up_proj.weight,
|
||||||
up_proj.bias,
|
up_proj.bias,
|
||||||
None, # up_quant
|
None, # up_quant
|
||||||
None, # up_A
|
None, # up_A
|
||||||
None, # up_B
|
None, # up_B
|
||||||
None, # up_scale
|
None, # up_scale
|
||||||
|
None, # up_lora_bias
|
||||||
|
None, # up_magnitude
|
||||||
down_proj.weight,
|
down_proj.weight,
|
||||||
down_proj.bias,
|
down_proj.bias,
|
||||||
None, # down_quant
|
None, # down_quant
|
||||||
None, # down_A
|
None, # down_A
|
||||||
None, # down_B
|
None, # down_B
|
||||||
None, # down_scale
|
None, # down_scale
|
||||||
|
None, # down_lora_bias
|
||||||
|
None, # down_magnitude
|
||||||
activation_forward,
|
activation_forward,
|
||||||
activation_backward,
|
activation_backward,
|
||||||
True, # inplace
|
True, # inplace
|
||||||
@@ -247,24 +254,31 @@ def test_lora_mlp_with_adapters(
|
|||||||
# Forward pass with adapters
|
# Forward pass with adapters
|
||||||
output = LoRA_MLP.apply(
|
output = LoRA_MLP.apply(
|
||||||
X,
|
X,
|
||||||
|
None, # X_drop
|
||||||
gate_proj.weight,
|
gate_proj.weight,
|
||||||
gate_proj.bias,
|
gate_proj.bias,
|
||||||
None,
|
None,
|
||||||
gate_A,
|
gate_A,
|
||||||
gate_B,
|
gate_B,
|
||||||
scale,
|
scale,
|
||||||
|
None, # gate_lora_bias
|
||||||
|
None, # gate_magnitude
|
||||||
up_proj.weight,
|
up_proj.weight,
|
||||||
up_proj.bias,
|
up_proj.bias,
|
||||||
None,
|
None,
|
||||||
up_A,
|
up_A,
|
||||||
up_B,
|
up_B,
|
||||||
scale,
|
scale,
|
||||||
|
None, # up_lora_bias
|
||||||
|
None, # up_magnitude
|
||||||
down_proj.weight,
|
down_proj.weight,
|
||||||
down_proj.bias,
|
down_proj.bias,
|
||||||
None,
|
None,
|
||||||
down_A,
|
down_A,
|
||||||
down_B,
|
down_B,
|
||||||
scale,
|
scale,
|
||||||
|
None, # down_lora_bias
|
||||||
|
None, # down_magnitude
|
||||||
activation_forward,
|
activation_forward,
|
||||||
activation_backward,
|
activation_backward,
|
||||||
True,
|
True,
|
||||||
@@ -334,25 +348,32 @@ def test_lora_qkv(sample_tensors):
|
|||||||
|
|
||||||
Q1, K1, V1 = LoRA_QKV.apply(
|
Q1, K1, V1 = LoRA_QKV.apply(
|
||||||
X,
|
X,
|
||||||
|
None, # X_drop
|
||||||
q_weight,
|
q_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None, # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||||
k_weight,
|
k_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
|
None, # K
|
||||||
v_weight,
|
v_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
True,
|
None,
|
||||||
|
None, # V
|
||||||
|
True, # inplace
|
||||||
)
|
)
|
||||||
|
|
||||||
assert Q1.shape == K1.shape == V1.shape == X.shape
|
assert Q1.shape == K1.shape == V1.shape == X.shape
|
||||||
@@ -366,25 +387,32 @@ def test_lora_qkv(sample_tensors):
|
|||||||
# Test with LoRA adapters
|
# Test with LoRA adapters
|
||||||
Q2, K2, V2 = LoRA_QKV.apply(
|
Q2, K2, V2 = LoRA_QKV.apply(
|
||||||
X,
|
X,
|
||||||
|
None, # X_drop
|
||||||
q_weight,
|
q_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
q_A,
|
q_A,
|
||||||
q_B,
|
q_B,
|
||||||
scale,
|
scale,
|
||||||
|
None,
|
||||||
|
None, # Q
|
||||||
k_weight,
|
k_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
k_A,
|
k_A,
|
||||||
k_B,
|
k_B,
|
||||||
scale,
|
scale,
|
||||||
|
None,
|
||||||
|
None, # K
|
||||||
v_weight,
|
v_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
v_A,
|
v_A,
|
||||||
v_B,
|
v_B,
|
||||||
scale,
|
scale,
|
||||||
True,
|
None,
|
||||||
|
None, # V
|
||||||
|
True, # inplace
|
||||||
)
|
)
|
||||||
|
|
||||||
assert Q2.shape == K2.shape == V2.shape == X.shape
|
assert Q2.shape == K2.shape == V2.shape == X.shape
|
||||||
@@ -427,7 +455,9 @@ def test_lora_o(sample_tensors):
|
|||||||
|
|
||||||
# Test forward pass
|
# Test forward pass
|
||||||
X.requires_grad = True
|
X.requires_grad = True
|
||||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
output = LoRA_O.apply(
|
||||||
|
X, None, W, b, None, A, B, scale, None, None
|
||||||
|
) # X_drop, ..., lora_bias, magnitude
|
||||||
|
|
||||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||||
|
|
||||||
@@ -542,6 +572,7 @@ def test_inplace_operations(sample_tensors, apply_function):
|
|||||||
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
|
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
|
||||||
device="cuda", dtype=torch.float16
|
device="cuda", dtype=torch.float16
|
||||||
),
|
),
|
||||||
|
"training": False,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user