feat: add sonicmoe (#3411)
* feat: add sonicmoe
* feat: add torch compile for routing
* feat: add routing smoke test
* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe
* fix: disable mlp kernel for sonicmoe too
* feat: update to sonicmoe release
* chore: update import following new sonicmoe changes
* feat: update handling for blackwell
* feat: add sonicmoe e2e test
* fix: installation for updated sonicmoe
* fix: git commit
* fix: ignore py req and fix metadata
* fix: increase min hidden size to match sonicmoe kernel min
* fix: attempt properly interleave and handle unpatch mid-test
* chore: refactor teardown better
* chore: refactor to re-use rearrange
* fix: add idempotency guard
* fix: address comments on CI memory and interleave
* fix: tests grad, param doublewrapped
@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
}
```

In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.

## Usage
@@ -21,23 +21,55 @@ plugins:
  - axolotl.integrations.kernels.KernelsPlugin

use_kernels: true

# Choose one (mutually exclusive):
use_scattermoe: true
# OR
use_sonicmoe: true
```

**Important:** Setting `experts_implementation` is incompatible with custom kernel options.

### SonicMoE installation

**Prerequisites:**

- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
- CUDA 12.9+ (13.0+ for B300)
- PyTorch 2.7+ (2.9.1 recommended)
- For B300: Triton 3.6.0

```bash
pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
```

See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.

**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
## How It Works

The `KernelsPlugin` runs before model loading and:

### ScatterMoE

1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.

This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).

### SonicMoE

1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports both softmax->topk and sigmoid->topk routing strategies.

Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.

#### Supported Models

See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).

## Limitations

ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash) at the moment.

SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.
@@ -6,7 +6,18 @@ LOG = get_logger(__name__)


class KernelsArgs(BaseModel):
    use_scattermoe: bool | None = None
    use_sonicmoe: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def check_mutually_exclusive(cls, data):
        if data.get("use_scattermoe") and data.get("use_sonicmoe"):
            raise ValueError(
                "Cannot use both ScatterMoE and SonicMoE simultaneously. "
                "Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
            )
        return data

    @model_validator(mode="before")
    @classmethod
@@ -36,11 +47,11 @@ class KernelsArgs(BaseModel):

    @model_validator(mode="before")
    @classmethod
    def disable_mlp_kernel(cls, data):
        if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True:
            if data.get("lora_mlp_kernel") is True:
                LOG.warning(
                    "Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues."
                )
            data["lora_mlp_kernel"] = False
            data["mlp_kernel"] = False
src/axolotl/integrations/kernels/constants.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""
Supported MoE block mappings for kernel integrations.

Maps model_type to the SparseMoeBlock class name(s) in transformers.
Used by both ScatterMoE and SonicMoE kernel paths.

Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
"""

import importlib

SPARSE_MOE_BLOCK = {
    # softmax -> topk routing
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
    "qwen3_next": "Qwen3NextSparseMoeBlock",
    "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
    # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
    "qwen3_omni_moe": [
        "Qwen3OmniMoeThinkerTextSparseMoeBlock",
        "Qwen3OmniMoeTalkerTextSparseMoeBlock",
    ],
    "olmoe": "OlmoeSparseMoeBlock",
    "mixtral": "MixtralSparseMoeBlock",
    "minimax": "MiniMaxSparseMoeBlock",
    # sigmoid -> topk routing (with group-based expert selection)
    "glm_moe_dsa": "GlmMoeDsaMoE",
    "deepseek_v3": "DeepseekV3MoE",
    "glm4_moe": "Glm4MoeMoE",
    "glm4_moe_lite": "Glm4MoeLiteMoE",
    "glm4v_moe": "Glm4vMoeTextMoE",
    # sigmoid -> topk routing (no group selection)
    "minimax_m2": "MiniMaxM2SparseMoeBlock",
    # Models below need custom routing (not yet implemented):
    # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",  # softmax->topk, e_score_correction_bias between softmax and topk
    # "deepseek_v2": "DeepseekV2Moe",  # softmax->topk, group_limited_greedy, different attr names (num_group)
    # "hunyuan_v1_moe": "HunYuanMoEV1Moe",  # softmax->topk, gate.wg (not gate.weight), scatter routing
    # "gpt_oss": "GptOssMLP",  # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
}


def resolve_moe_block_classes(model_type: str):
    """Resolve all MoE block classes from transformers for the given model type.

    Returns a list of classes (one for most models, multiple for models with
    distinct MoE block types like qwen3_omni_moe).
    """
    entry = SPARSE_MOE_BLOCK.get(model_type)
    if entry is None:
        raise ValueError(
            f"Unsupported MoE model type '{model_type}'. "
            f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}"
        )

    cls_names = entry if isinstance(entry, list) else [entry]
    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
    module = importlib.import_module(module_path)

    classes = []
    for cls_name in cls_names:
        moe_cls = getattr(module, cls_name, None)
        if moe_cls is None:
            raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'")
        classes.append(moe_cls)

    return classes
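The str-or-list convention used by `SPARSE_MOE_BLOCK` can be exercised in isolation. A minimal sketch of just the normalization step (class names copied from the mapping above; no transformers import, and `resolve_class_names` is a hypothetical standalone helper, not the plugin's own function):

```python
# Minimal sketch of the str-or-list normalization in resolve_moe_block_classes.
SPARSE_MOE_BLOCK = {
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "qwen3_omni_moe": [
        "Qwen3OmniMoeThinkerTextSparseMoeBlock",
        "Qwen3OmniMoeTalkerTextSparseMoeBlock",
    ],
}


def resolve_class_names(model_type: str) -> list[str]:
    """Normalize a mapping entry (str or list) to a list of class names."""
    entry = SPARSE_MOE_BLOCK.get(model_type)
    if entry is None:
        raise ValueError(f"Unsupported MoE model type '{model_type}'")
    return entry if isinstance(entry, list) else [entry]
```

Most models resolve to a single class; qwen3_omni_moe resolves to two (Thinker + Talker), and unknown types raise.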
@@ -1,14 +1,59 @@
import importlib
import os
from pathlib import Path

import torch

from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _check_sonicmoe_gpu_compat():
    """Validate GPU compute capability for SonicMoE and configure env.

    Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).
    B300 (sm_103) additionally requires Triton 3.6.0.
    """
    if not torch.cuda.is_available():
        return

    cc = torch.cuda.get_device_capability()

    if cc < (9, 0):
        raise RuntimeError(
            f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, "
            f"but detected sm_{cc[0]}{cc[1]}."
        )

    if cc > (10, 3):
        raise RuntimeError(
            f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. "
            f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)."
        )

    # Blackwell (sm_100+): enable QuACK GEMM kernels
    if cc >= (10, 0):
        os.environ.setdefault("USE_QUACK_GEMM", "1")
        LOG.info(
            f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1"
        )

    # B300 (sm_103): requires Triton 3.6.0
    if cc == (10, 3):
        triton_spec = importlib.util.find_spec("triton")
        if triton_spec is None:
            raise RuntimeError(
                "B300 (sm_103) requires Triton 3.6.0, but Triton is not installed."
            )
        import triton

        triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2])
        if triton_version != (3, 6):
            raise RuntimeError(
                f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}."
            )
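The gating above is a sequence of Python tuple comparisons on the `(major, minor)` pair returned by `torch.cuda.get_device_capability()`. A dependency-free sketch of just that decision logic (the function name is hypothetical, not part of the plugin):

```python
def classify_compute_capability(cc: tuple[int, int]) -> str:
    """Mirror the sm_90..sm_103 gating with plain tuple comparisons."""
    if cc < (9, 0) or cc > (10, 3):
        # Below Hopper (e.g. Ada sm_89) or above sm_103: rejected
        return "unsupported"
    if cc >= (10, 0):
        # Blackwell range; the plugin would also set USE_QUACK_GEMM=1 here
        return "blackwell"
    return "hopper"
```

For example, `(9, 0)` (H100) classifies as Hopper, `(10, 0)` (B200) and `(10, 3)` (B300) as Blackwell, and `(8, 9)` (Ada) as unsupported.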
class KernelsPlugin(BasePlugin):
@@ -19,8 +64,32 @@ class KernelsPlugin(BasePlugin):
        if cfg.use_scattermoe:
            self._register_kernels()
            self._kernelize_model(cfg.model_config_type)
        elif cfg.use_sonicmoe:
            if not importlib.util.find_spec("sonicmoe"):
                raise RuntimeError(
                    "SonicMoE is not installed. See installation instructions at "
                    "https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation"
                )

            _check_sonicmoe_gpu_compat()

            from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe

            LOG.info(
                f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
            )
            patch_sonicmoe(
                cfg.model_config_type,
                torch_compile=bool(getattr(cfg, "torch_compile", False)),
            )

    def _register_kernels(self):
        from kernels import (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
        )

        plugin_root = Path(__file__).parent
        register_kernel_mapping(
            {
@@ -42,25 +111,11 @@ class KernelsPlugin(BasePlugin):
        )

    def _kernelize_model(self, model_type: str):
        from kernels import replace_kernel_forward_from_hub

        from axolotl.integrations.kernels.constants import resolve_moe_block_classes

        for model_moe_cls in resolve_moe_block_classes(model_type):
            replace_kernel_forward_from_hub(
                model_moe_cls, "HFScatterMoEParallelExperts"
            )
src/axolotl/integrations/kernels/sonicmoe/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .patch import patch_sonicmoe

__all__ = ["patch_sonicmoe"]

src/axolotl/integrations/kernels/sonicmoe/patch.py (new file, 213 lines)
@@ -0,0 +1,213 @@
"""
|
||||
SonicMoE patching for SparseMoeBlock forward pass.
|
||||
|
||||
Monkeypatches the SparseMoeBlock class for a given model type to use
|
||||
SonicMoE's optimized kernels. Two forward paths are supported:
|
||||
|
||||
1. **General routing path** (routing_fn is not None):
|
||||
Uses a custom routing function + ``moe_general_routing_inputs``.
|
||||
Suitable for models with non-standard routing (softmax->topk, sigmoid->topk).
|
||||
|
||||
2. **Fused topk->softmax path** (routing_fn is None):
|
||||
Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation.
|
||||
Suitable for models with simple topk->softmax routing.
|
||||
|
||||
Weight format conversion (interleave/deinterleave) is handled by the
|
||||
WeightConverter system, so the forward assumes weights are already in
|
||||
interleaved format.
|
||||
|
||||
Shared experts are handled generically: if the block has a ``shared_expert``
|
||||
or ``shared_experts`` attribute, its output is computed alongside the routed
|
||||
experts and added to the final output. An optional ``shared_expert_gate``
|
||||
applies sigmoid gating to the shared expert contribution.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
|
||||
"""Main entry point: patch SparseMoeBlock for SonicMoE support.
|
||||
|
||||
Args:
|
||||
model_type: The HuggingFace model type (e.g. "qwen3_moe").
|
||||
torch_compile: If True, wrap routing functions with torch.compile
|
||||
for kernel fusion (fuses softmax+topk+renorm into fewer launches).
|
||||
"""
|
||||
from .routing import get_model_moe_config
|
||||
from .weight_converter import register_sonicmoe_weight_converter
|
||||
|
||||
routing_fn, activation, router_attr = get_model_moe_config(model_type)
|
||||
|
||||
if torch_compile and routing_fn is not None:
|
||||
routing_fn = _try_compile_routing(routing_fn)
|
||||
|
||||
for moe_cls in resolve_moe_block_classes(model_type):
|
||||
_patch_forward(moe_cls, routing_fn, activation, router_attr)
|
||||
register_sonicmoe_weight_converter(model_type)
|
||||
|
||||
|
||||
def _try_compile_routing(routing_fn):
|
||||
"""Attempt to torch.compile the routing function, fall back to eager on failure."""
|
||||
try:
|
||||
compiled_fn = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False)
|
||||
LOG.info(f"torch.compile enabled for routing function: {routing_fn.__name__}")
|
||||
return compiled_fn
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
LOG.warning(
|
||||
f"torch.compile failed for routing function {routing_fn.__name__}, "
|
||||
f"falling back to eager: {exc}"
|
||||
)
|
||||
return routing_fn
|
||||
|
||||
|
||||
def _patch_forward(moe_cls, routing_fn, activation, router_attr):
|
||||
"""Monkeypatch the SparseMoeBlock class with a SonicMoE forward.
|
||||
|
||||
The patched forward handles shared experts generically: if
|
||||
``self.shared_expert`` or ``self.shared_experts`` exists, it is computed
|
||||
and added to the routed output. If ``self.shared_expert_gate`` also exists,
|
||||
it applies sigmoid gating to the shared expert contribution (as in qwen2_moe).
|
||||
|
||||
Args:
|
||||
moe_cls: The SparseMoeBlock class to patch.
|
||||
routing_fn: Routing function (e.g. softmax_topk_routing), or None
|
||||
for the fused moe_TC_softmax_topk_layer path.
|
||||
activation: SonicMoE ActivationType enum value.
|
||||
router_attr: Name of the router module attribute on the MoE block.
|
||||
"""
|
||||
if hasattr(moe_cls, "_original_forward"):
|
||||
LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping")
|
||||
return
|
||||
|
||||
original_forward = moe_cls.forward
|
||||
|
||||
if routing_fn is not None:
|
||||
_make_general_forward(moe_cls, routing_fn, activation)
|
||||
else:
|
||||
_make_fused_forward(moe_cls, activation, router_attr)
|
||||
|
||||
moe_cls._original_forward = original_forward
|
||||
LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation")
|
||||
|
||||
|
||||
def _make_general_forward(moe_cls, routing_fn, activation):
|
||||
"""Create forward using routing_fn + moe_general_routing_inputs."""
|
||||
|
||||
def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
from sonicmoe import moe_general_routing_inputs
|
||||
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states_flat = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Shared expert (computed early, matching original model ordering)
|
||||
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||
|
||||
# Routing
|
||||
router_scores, token_indices, expert_indices, _router_logits = routing_fn(
|
||||
hidden_states_flat, self
|
||||
)
|
||||
|
||||
# Permute weights to SonicMoE layout:
|
||||
# gate_up: [E, 2*I, H] -> [2*I, H, E]
|
||||
# down: [E, H, I] -> [H, I, E]
|
||||
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
|
||||
down_weight = self.experts.down_proj.permute(1, 2, 0)
|
||||
E = gate_up_weight.shape[-1]
|
||||
|
||||
output, _ = moe_general_routing_inputs(
|
||||
hidden_states_flat,
|
||||
router_scores,
|
||||
token_indices,
|
||||
expert_indices,
|
||||
gate_up_weight,
|
||||
None, # b1 (no gate/up bias)
|
||||
down_weight,
|
||||
None, # b2 (no down bias)
|
||||
E,
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
activation,
|
||||
False, # is_inference_mode
|
||||
)
|
||||
|
||||
# Add shared expert contribution if present
|
||||
if shared_expert_output is not None:
|
||||
if hasattr(self, "shared_expert_gate"):
|
||||
shared_expert_output = (
|
||||
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
|
||||
* shared_expert_output
|
||||
)
|
||||
output = output + shared_expert_output
|
||||
|
||||
return output.view(batch_size, sequence_length, hidden_dim)
|
||||
|
||||
moe_cls.forward = sonicmoe_forward
|
||||
|
||||
|
||||
def _make_fused_forward(moe_cls, activation, router_attr):
|
||||
"""Create forward using moe_TC_softmax_topk_layer (topk -> softmax)."""
|
||||
|
||||
def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
from sonicmoe import moe_TC_softmax_topk_layer
|
||||
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states_flat = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Shared expert (computed early, matching original model ordering)
|
||||
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||
|
||||
router = getattr(self, router_attr)
|
||||
|
||||
# Permute weights to SonicMoE layout:
|
||||
# gate_up: [E, 2*I, H] -> [2*I, H, E]
|
||||
# down: [E, H, I] -> [H, I, E]
|
||||
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
|
||||
down_weight = self.experts.down_proj.permute(1, 2, 0)
|
||||
|
||||
output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
|
||||
hidden_states_flat,
|
||||
router.weight,
|
||||
gate_up_weight,
|
||||
None, # b1 (no gate/up bias)
|
||||
down_weight,
|
||||
None, # b2 (no down bias)
|
||||
router.top_k,
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
activation,
|
||||
False, # is_inference_mode
|
||||
)
|
||||
|
||||
# Add shared expert contribution if present
|
||||
if shared_expert_output is not None:
|
||||
if hasattr(self, "shared_expert_gate"):
|
||||
shared_expert_output = (
|
||||
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
|
||||
* shared_expert_output
|
||||
)
|
||||
output = output + shared_expert_output
|
||||
|
||||
return output.view(batch_size, sequence_length, hidden_dim)
|
||||
|
||||
moe_cls.forward = sonicmoe_fused_forward
|
||||
|
||||
|
||||
def _compute_shared_expert(moe_block, hidden_states_flat):
|
||||
"""Compute shared expert output if the block has one.
|
||||
|
||||
Handles singular (qwen2_moe: ``shared_expert``), plural
|
||||
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
|
||||
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
|
||||
"""
|
||||
shared_expert = (
|
||||
getattr(moe_block, "shared_expert", None)
|
||||
or getattr(moe_block, "shared_experts", None)
|
||||
or getattr(moe_block, "shared_mlp", None)
|
||||
)
|
||||
if shared_expert is not None:
|
||||
return shared_expert(hidden_states_flat)
|
||||
return None
|
||||
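The sigmoid-gated shared-expert addition used in both forwards reduces to `output += sigmoid(gate(x)) * shared(x)`. A NumPy sketch with toy values (the arrays are illustrative, not real model activations):

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


routed_output = np.array([1.0, 2.0])  # output of the routed experts
shared_output = np.array([0.4, 0.8])  # output of the shared expert MLP
gate_logit = 0.0                      # shared_expert_gate(x); sigmoid(0) = 0.5

# Gated shared-expert contribution added to the routed output
output = routed_output + sigmoid(gate_logit) * shared_output
# → [1.2, 2.4]
```

With no `shared_expert_gate` attribute, the shared output would be added ungated, as in the DeepSeek-style blocks.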
src/axolotl/integrations/kernels/sonicmoe/routing.py (new file, 219 lines)
@@ -0,0 +1,219 @@
"""
|
||||
Routing functions for SonicMoE integration.
|
||||
|
||||
Different MoE architectures use different routing strategies:
|
||||
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
|
||||
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
|
||||
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
|
||||
|
||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_model_moe_config(model_type: str):
|
||||
"""Returns (routing_fn, activation, router_attr) for a given model type.
|
||||
|
||||
Args:
|
||||
model_type: HuggingFace model type string.
|
||||
|
||||
Returns:
|
||||
routing_fn: Callable or None. None signals the fused
|
||||
moe_TC_softmax_topk_layer path (topk -> softmax models).
|
||||
activation: SonicMoE ActivationType enum value.
|
||||
router_attr: Name of the router module attribute on the MoE block
|
||||
(e.g. "gate" or "router").
|
||||
|
||||
The activation type cannot be derived from config.hidden_act because
|
||||
e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
|
||||
(act_fn(gate) * up pattern). So we specify it per model type.
|
||||
"""
|
||||
from sonicmoe.enums import ActivationType
|
||||
|
||||
if model_type in (
|
||||
"qwen2_moe",
|
||||
"qwen3_moe",
|
||||
"qwen3_5_moe",
|
||||
"qwen3_next",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3_omni_moe",
|
||||
"olmoe",
|
||||
"mixtral",
|
||||
"minimax",
|
||||
):
|
||||
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
|
||||
elif model_type in (
|
||||
"glm_moe_dsa",
|
||||
"deepseek_v3",
|
||||
"glm4_moe",
|
||||
"glm4_moe_lite",
|
||||
"glm4v_moe",
|
||||
"minimax_m2",
|
||||
):
|
||||
return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
|
||||
# elif model_type in ("ernie4_5_moe",):
|
||||
# # Softmax→topk with e_score_correction_bias applied between softmax and topk.
|
||||
# return ..., ActivationType.SWIGLU, "gate"
|
||||
# elif model_type in ("deepseek_v2",):
|
||||
# # Softmax→topk with group_limited_greedy. Different attr names: num_group
|
||||
# # (not n_group), gate is nn.Linear (not a router class).
|
||||
# return ..., ActivationType.SWIGLU, "gate"
|
||||
# elif model_type in ("hunyuan_v1_moe",):
|
||||
# # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
|
||||
# # top_k on block not gate, creates scatter routing matrix.
|
||||
# return ..., ActivationType.SWIGLU, "gate"
|
||||
# Fused topk -> softmax path (routing_fn=None):
|
||||
# elif model_type in ("gpt_oss",):
|
||||
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
|
||||
# # ignores (it only takes router_w, not bias). Also has transposed
|
||||
# # weight layout [E, H, 2*I] and custom GLU activation.
|
||||
# return None, ActivationType.SWIGLU, "router"
|
||||
else:
|
||||
raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")
|
||||
|
||||
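The docstring's point about SwiGLU (`act_fn(gate) * up`, even though config reports "silu") is easy to see in isolation. A NumPy sketch of the pattern (toy inputs, not a kernel implementation):

```python
import numpy as np


def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))


def swiglu(gate, up):
    # The act_fn(gate) * up pattern that qwen3_moe-style MLPs actually use
    return silu(gate) * up


gate = np.array([0.0, 1.0])
up = np.array([2.0, 2.0])
out = swiglu(gate, up)
# silu(0) = 0, so out[0] == 0.0 regardless of up[0]
```

This is why the config's `hidden_act` string alone cannot pick the SonicMoE activation enum: the gating structure matters, not just the element-wise function.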
def softmax_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = gate.top_k

    # Compute router logits and softmax over all experts
    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts per token
    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each

    # Renormalize if configured (default True for models without the attribute,
    # e.g. Mixtral/MiniMax which always normalize)
    if getattr(gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # no-op: matches transformers which casts to softmax output dtype (float32).
    # top_values = top_values.to(router_probs.dtype)

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout:
    # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] - this is required by SonicMoE.
    # Expert sorting is handled internally by general_routing_router_metadata.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
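The "naturally sorted ascending" claim about the flattened `[T, K]` token-index layout can be checked standalone. A NumPy sketch with T=4 tokens and K=2 experts per token:

```python
import numpy as np

T, K = 4, 2
# Same construction as the arange().unsqueeze(1).expand(T, K) in the routing fn
token_indices = np.arange(T, dtype=np.int32)[:, None].repeat(K, axis=1)  # [T, K]
flat_token_idx = token_indices.reshape(-1)  # [T*K]
# → [0, 0, 1, 1, 2, 2, 3, 3]: ascending, as the kernel's input contract requires
```

Because each row of the `[T, K]` matrix holds one token's K entries, row-major flattening always yields a non-decreasing token index sequence, so no extra sort is needed on the token axis.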
|
||||
|
||||
def sigmoid_topk_routing(
|
||||
hidden_states: torch.Tensor, moe_block
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Sigmoid-based routing: sigmoid -> optional group selection -> topk.
|
||||
|
||||
Supports two variants:
|
||||
- **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
|
||||
bias on gate, group-based masking before topk.
|
||||
- **No group selection** (minimax_m2): n_group == 1 (or absent),
|
||||
bias on moe_block, straight topk from all experts.
|
||||
|
||||
Final routing weights come from the original sigmoid scores (not
|
||||
bias-corrected), with optional renormalization and scaling.
|
||||
|
||||
Args:
|
||||
hidden_states: [T, H] flattened token representations
|
||||
moe_block: MoE block module (accesses moe_block.gate.* and
|
||||
optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
|
||||
.routed_scaling_factor, .n_routed_experts)
|
||||
|
||||
Returns:
|
||||
router_scores: [T*K] flattened scores (float32)
|
||||
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
|
||||
expert_indices: [T*K] which expert (int32)
|
||||
router_logits: [T, E] original logits for aux loss
|
||||
"""
|
||||
gate = moe_block.gate
|
||||
T, H = hidden_states.shape
|
||||
K = moe_block.top_k
|
||||
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
|
||||
n_group = getattr(moe_block, "n_group", 1)
|
||||
|
||||
# Compute router logits and sigmoid probabilities
|
||||
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
|
||||
router_probs = router_logits.sigmoid() # [T, E]
|
||||
|
||||
# Bias-corrected scores for expert selection (not used for final weights).
|
||||
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
|
||||
e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
|
||||
if e_score_correction_bias is None:
|
||||
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
|
||||
if e_score_correction_bias is None:
|
||||
raise AttributeError(
|
||||
f"sigmoid_topk_routing requires e_score_correction_bias on "
|
||||
        f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
    )

    scores_for_choice = router_probs + e_score_correction_bias

    # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, E // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    # Final topk from (possibly masked) scores
    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]

    # Gather weights from original sigmoid scores (not bias-corrected)
    topk_weights = router_probs.gather(1, topk_indices)

    # Optional renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
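The flattening step above leans on a simple property: expanding `arange(T)` across K columns and flattening row-major yields each token index repeated K times, so the result is already sorted ascending, which SonicMoE requires. A dependency-free sketch of that layout (plain lists standing in for the int32 tensors):

```python
def flatten_token_indices(T: int, K: int) -> list[int]:
    # Mirrors torch.arange(T).unsqueeze(1).expand(T, K).reshape(-1):
    # each token index is repeated K times, row-major.
    return [t for t in range(T) for _ in range(K)]

idx = flatten_token_indices(T=4, K=2)
assert idx == [0, 0, 1, 1, 2, 2, 3, 3]
# Sorted ascending without any explicit sort:
assert all(a <= b for a, b in zip(idx, idx[1:]))
```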
181  src/axolotl/integrations/kernels/sonicmoe/weight_converter.py  Normal file
@@ -0,0 +1,181 @@
"""
Custom WeightConverter operations for SonicMoE weight format conversion.

SonicMoE requires gate_up_proj weights in interleaved format:
- Standard (concatenated): [E, 2*I, H] where the first I rows are gate, the last I rows are up
- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...]

These ConversionOps integrate with transformers' WeightConverter system so that
weights are transparently converted during loading and reverted during saving.
"""

from typing import Any

import torch
from einops import rearrange
from transformers.core_model_loading import ConversionOps

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
    """[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension."""
    return rearrange(tensor, "... (two out) h -> ... (out two) h", two=2)


def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
    """[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension."""
    return rearrange(tensor, "... (out two) h -> ... (two out) h", two=2)
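The two einops patterns above are just a row permutation along the 2*I dimension: `"(two out) h -> (out two) h"` with `two=2` sends gate row j to position 2j and up row j to position 2j+1, and the inverse pattern undoes it. A dependency-free sketch of the same permutation on a list of row labels (pure Python, no torch/einops):

```python
def interleave(rows: list) -> list:
    # [g0..g_{I-1}, u0..u_{I-1}] -> [g0, u0, g1, u1, ...]
    half = len(rows) // 2
    gate, up = rows[:half], rows[half:]
    return [r for pair in zip(gate, up) for r in pair]

def deinterleave(rows: list) -> list:
    # Even positions are gate rows, odd positions are up rows.
    return rows[0::2] + rows[1::2]

rows = ["g0", "g1", "u0", "u1"]
assert interleave(rows) == ["g0", "u0", "g1", "u1"]
assert deinterleave(interleave(rows)) == rows  # exact round trip
```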
class ConcatenatedToInterleaved(ConversionOps):
    """Convert concatenated gate/up projections to interleaved format.

    Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
    Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]

    This operation is applied along ``dim`` (default 1, the 2*I dimension).
    """

    def __init__(self, dim: int = 1):
        self.dim = dim

    @torch.no_grad()
    def convert(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        target_pattern = self._get_target_pattern(
            input_dict, source_patterns, target_patterns
        )
        tensors = next(iter(input_dict.values()))
        tensor = tensors[0] if isinstance(tensors, list) else tensors

        interleaved = interleave_gate_up(tensor)

        return {target_pattern: interleaved}

    def _get_target_pattern(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
    ) -> str:
        # Follow the same logic as Transpose.get_target_pattern
        if len(input_dict) != 1:
            raise ValueError("Undefined Operation encountered!")
        if len(target_patterns) > 1:
            if len(source_patterns) == 1:
                return source_patterns[0]
            raise ValueError("Undefined Operation encountered!")
        return target_patterns[0]

    @property
    def reverse_op(self) -> ConversionOps:
        return InterleavedToConcatenated(self.dim)


class InterleavedToConcatenated(ConversionOps):
    """Convert interleaved gate/up projections back to concatenated format.

    Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
    Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]

    This is the reverse of ``ConcatenatedToInterleaved``.
    """

    def __init__(self, dim: int = 1):
        self.dim = dim

    @torch.no_grad()
    def convert(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        target_pattern = self._get_target_pattern(
            input_dict, source_patterns, target_patterns
        )
        tensors = next(iter(input_dict.values()))
        tensor = tensors[0] if isinstance(tensors, list) else tensors

        concatenated = deinterleave_gate_up(tensor)

        return {target_pattern: concatenated}

    def _get_target_pattern(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
    ) -> str:
        if len(input_dict) != 1:
            raise ValueError("Undefined Operation encountered!")
        if len(target_patterns) > 1:
            if len(source_patterns) == 1:
                return source_patterns[0]
            raise ValueError("Undefined Operation encountered!")
        return target_patterns[0]

    @property
    def reverse_op(self) -> ConversionOps:
        return ConcatenatedToInterleaved(self.dim)


def register_sonicmoe_weight_converter(model_type: str):
    """Override the conversion mapping to add interleave step for gate_up_proj.

    Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
    converter chain. For example, qwen3_moe's chain becomes:
    MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)

    The reverse is auto-generated for saving:
    InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
    """
    from transformers.conversion_mapping import (
        get_checkpoint_conversion_mapping,
        register_checkpoint_conversion_mapping,
    )

    existing = get_checkpoint_conversion_mapping(model_type)
    if existing is None:
        LOG.warning(
            f"No conversion mapping found for model type '{model_type}'. "
            "SonicMoE weight interleaving will not be applied during checkpoint loading."
        )
        return

    # Find the gate_up_proj converter and append ConcatenatedToInterleaved
    patched = False
    for converter in existing:
        if hasattr(converter, "operations") and any(
            "gate_up_proj" in pat for pat in converter.target_patterns
        ):
            # Guard against double registration (e.g. plugin reloaded)
            if any(
                isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
            ):
                LOG.info(
                    f"SonicMoE weight converter already registered for '{model_type}'"
                )
                return
            converter.operations.append(ConcatenatedToInterleaved(dim=1))
            patched = True
            break

    if not patched:
        LOG.warning(
            f"Could not find gate_up_proj converter for model type '{model_type}'. "
            "SonicMoE weight interleaving will not be applied during checkpoint loading."
        )
        return

    register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
    LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
288  tests/e2e/integrations/test_sonicmoe.py  Normal file
@@ -0,0 +1,288 @@
"""
End-to-end gradient and convergence tests for SonicMoE integration.

Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- transformers with Qwen3MoE support

Usage:
    pytest tests/e2e/integrations/test_sonicmoe.py -v -s
"""

import importlib.util
import math

import pytest
import torch

_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)

pytestmark = [
    pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
    pytest.mark.skipif(
        not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
    ),
    pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
]


def _create_tiny_qwen3_config():
    """Create a minimal Qwen3MoE config for fast testing."""
    from transformers import AutoConfig

    config = AutoConfig.for_model("qwen3_moe")
    config.hidden_size = 512
    config.intermediate_size = 1024
    config.moe_intermediate_size = 64
    config.num_attention_heads = 16
    config.num_key_value_heads = 2
    config.head_dim = 32
    config.num_hidden_layers = 2
    config.num_experts = 8
    config.num_experts_per_tok = 2
    config.vocab_size = 1000
    config.max_position_embeddings = 128
    config.norm_topk_prob = True
    config.torch_dtype = torch.bfloat16
    return config


def _interleave_gate_up_weights(model):
    """Interleave all gate_up_proj parameters in-place for SonicMoE."""
    from axolotl.integrations.kernels.sonicmoe.weight_converter import (
        interleave_gate_up,
    )

    with torch.no_grad():
        for name, param in model.named_parameters():
            if "gate_up_proj" in name:
                param.copy_(interleave_gate_up(param))


def _unpatch_sonicmoe():
    """Restore original forward on the MoE block class if it was patched."""
    from axolotl.integrations.kernels.constants import resolve_moe_block_classes

    for moe_cls in resolve_moe_block_classes("qwen3_moe"):
        if hasattr(moe_cls, "_original_forward"):
            moe_cls.forward = moe_cls._original_forward
            del moe_cls._original_forward


class TestSonicMoEForwardCorrectness:
    """Verify SonicMoE-patched model produces same output as original."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_forward_output_matches(self):
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # Original model
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()

        with torch.no_grad():
            out_orig = model_orig(input_ids)

        # Patched model (same weights, interleaved for SonicMoE)
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())

        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)

        with torch.no_grad():
            out_patched = model_patched(input_ids)

        max_diff = (out_orig.logits - out_patched.logits).abs().max().item()
        assert torch.allclose(
            out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1
        ), f"Output mismatch: max diff={max_diff:.6f}"


class TestSonicMoEGradientCorrectness:
    """Compare gradients between original HuggingFace and SonicMoE-patched forward."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_gradients_match(self):
        """Verify all parameter gradients match between original and patched."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
        from axolotl.integrations.kernels.sonicmoe.weight_converter import (
            deinterleave_gate_up,
        )

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # ---------- Original model ----------
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        out_orig = model_orig(input_ids, labels=input_ids)
        out_orig.loss.backward()
        grads_orig = {
            n: p.grad.float().clone()
            for n, p in model_orig.named_parameters()
            if p.grad is not None
        }
        loss_orig = out_orig.loss.item()

        # ---------- SonicMoE-patched model (same weights, interleaved) ----------
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())

        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)

        out_patched = model_patched(input_ids, labels=input_ids)
        out_patched.loss.backward()
        grads_patched = {}
        for n, p in model_patched.named_parameters():
            if p.grad is None:
                continue
            g = p.grad.float().clone()
            # gate_up_proj grads are in interleaved layout, de-interleave to match orig
            if "gate_up_proj" in n:
                g = deinterleave_gate_up(g)
            grads_patched[n] = g
        loss_patched = out_patched.loss.item()

        # ---------- Compare ----------
        assert abs(loss_orig - loss_patched) < 0.5, (
            f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}"
        )

        # All parameters with gradients in original should have them in patched
        missing = set(grads_orig.keys()) - set(grads_patched.keys())
        assert not missing, f"Missing gradients in patched model: {missing}"

        # Compare gradient values
        # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,
        # so use generous tolerance: flag only if both rel >10% AND abs >1e-2
        mismatches = []
        for name in grads_orig:
            if name not in grads_patched:
                continue
            g_orig = grads_orig[name]
            g_patched = grads_patched[name]
            max_diff = (g_orig - g_patched).abs().max().item()
            rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)

            if rel_diff > 0.1 and max_diff > 1e-2:
                mismatches.append(
                    f"  {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}"
                )

        assert not mismatches, (
            "Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\n"
            + "\n".join(mismatches)
        )

    def test_router_weights_receive_gradients(self):
        """Verify that router (gate) weights get non-zero gradients."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        out = model(input_ids, labels=input_ids)
        out.loss.backward()

        gate_grads_found = False
        for name, param in model.named_parameters():
            if "gate" in name and "weight" in name:
                gate_grads_found = True
                assert param.grad is not None, f"No gradient for router: {name}"
                assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}"

        assert gate_grads_found, "No gate.weight parameters found in model"


class TestSonicMoETrainingConvergence:
    """Verify loss decreases during training with SonicMoE."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_loss_decreases(self):
        """Run 30 training steps, verify loss decreases and no NaN/Inf."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")

        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        losses = []

        for step in range(30):
            out = model(input_ids, labels=input_ids)
            loss = out.loss
            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
            losses.append(loss.item())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        assert losses[-1] < losses[0], (
            f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
        )

    def test_expert_weights_update(self):
        """Verify expert weights change during training (not frozen)."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")

        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        # Snapshot expert weights before training
        expert_weights_before = {}
        for name, param in model.named_parameters():
            if "experts" in name:
                expert_weights_before[name] = param.data.clone()

        assert expert_weights_before, "No expert parameters found"

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        for _ in range(5):
            out = model(input_ids, labels=input_ids)
            out.loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Check that expert weights changed
        changed = 0
        for name, param in model.named_parameters():
            if name in expert_weights_before:
                if not torch.equal(param.data, expert_weights_before[name]):
                    changed += 1

        assert changed > 0, "No expert weights changed after 5 training steps"
@@ -6,7 +6,7 @@
 Unit tests for scattermoe-lora code-review fixes.
 
 Tests cover:
-- KernelsArgs validator: disable_mlp_kernel_scattermoe
+- KernelsArgs validator: disable_mlp_kernel
 - CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
 - ParallelExperts: scaling=0.0 not treated as falsy
 - single2scatter: non-aligned K/N dimensions
@@ -20,12 +20,12 @@ import pytest
 import torch
 
 # ============================================================================
-# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
+# 1. KernelsArgs: disable_mlp_kernel validator
 # ============================================================================
 
 
 class TestKernelsArgsValidator:
-    """Test that disable_mlp_kernel_scattermoe sets both flags correctly.
+    """Test that disable_mlp_kernel sets both flags correctly.
 
     These tests call the validator classmethod directly on raw dicts,
     since lora_mlp_kernel / mlp_kernel are not declared model fields.
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
             "use_scattermoe": True,
             "lora_mlp_kernel": True,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["lora_mlp_kernel"] is False
         assert result["mlp_kernel"] is False
 
@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
             "use_kernels": True,
             "use_scattermoe": True,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["mlp_kernel"] is False
         # lora_mlp_kernel was not in data, should not be added
         assert "lora_mlp_kernel" not in result
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
             "use_scattermoe": True,
             "lora_mlp_kernel": False,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["lora_mlp_kernel"] is False
 
     def test_no_change_when_scattermoe_disabled(self):
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
             "use_scattermoe": False,
             "lora_mlp_kernel": True,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["lora_mlp_kernel"] is True
428  tests/integrations/test_sonicmoe.py  Normal file
@@ -0,0 +1,428 @@
"""Unit tests for the SonicMoE integration."""

from types import SimpleNamespace

import pytest
import torch

from axolotl.integrations.kernels.args import KernelsArgs
from axolotl.integrations.kernels.sonicmoe.routing import (
    sigmoid_topk_routing,
    softmax_topk_routing,
)
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
    ConcatenatedToInterleaved,
    InterleavedToConcatenated,
    register_sonicmoe_weight_converter,
)


class TestKernelsArgs:
    def test_mutual_exclusivity_raises(self):
        with pytest.raises(ValueError, match="Cannot use both"):
            KernelsArgs.model_validate({"use_scattermoe": True, "use_sonicmoe": True})

    def test_sonicmoe_only(self):
        result = KernelsArgs.model_validate({"use_sonicmoe": True})
        assert result.use_sonicmoe is True
        assert result.use_scattermoe is None

    def test_scattermoe_only(self):
        result = KernelsArgs.model_validate({"use_scattermoe": True})
        assert result.use_scattermoe is True
        assert result.use_sonicmoe is None

    def test_neither_set(self):
        result = KernelsArgs.model_validate({})
        assert result.use_scattermoe is None
        assert result.use_sonicmoe is None

    def test_disables_mlp_kernel_when_sonicmoe(self):
        data = {"use_sonicmoe": True, "lora_mlp_kernel": True}
        result = KernelsArgs.disable_mlp_kernel(data)
        assert result["lora_mlp_kernel"] is False
        assert result["mlp_kernel"] is False


class TestConcatenatedToInterleaved:
    @pytest.fixture
    def sample_tensor(self):
        """Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values."""
        E, I, H = 2, 2, 3  # noqa: E741
        gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)
        up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)
        return torch.cat([gate, up], dim=1)

    def test_interleave_rows_alternate(self, sample_tensor):
        op = ConcatenatedToInterleaved(dim=1)
        result = op.convert(
            {"test": sample_tensor},
            source_patterns=["test"],
            target_patterns=["test"],
        )
        interleaved = result["test"]

        # For expert 0: even rows should be gate, odd rows should be up
        E, two_I, H = sample_tensor.shape
        I = two_I // 2  # noqa: E741
        gate_orig = sample_tensor[:, :I, :]
        up_orig = sample_tensor[:, I:, :]

        assert torch.equal(interleaved[:, 0::2, :], gate_orig)
        assert torch.equal(interleaved[:, 1::2, :], up_orig)

    def test_interleave_handles_list_input(self, sample_tensor):
        op = ConcatenatedToInterleaved(dim=1)
        result = op.convert(
            {"test": [sample_tensor]},
            source_patterns=["test"],
            target_patterns=["test"],
        )
        assert result["test"].shape == sample_tensor.shape

    def test_reverse_op_type(self):
        op = ConcatenatedToInterleaved(dim=1)
        assert isinstance(op.reverse_op, InterleavedToConcatenated)
        assert op.reverse_op.dim == 1


class TestInterleavedToConcatenated:
    @pytest.fixture
    def interleaved_tensor(self):
        """Create an interleaved tensor [E=2, 2*I=4, H=3]."""
        E, I, H = 2, 2, 3  # noqa: E741
        gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)
        up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)
        interleaved = torch.empty(E, 2 * I, H)
        interleaved[:, 0::2, :] = gate
        interleaved[:, 1::2, :] = up
        return interleaved

    def test_deinterleave_gate_up_separated(self, interleaved_tensor):
        op = InterleavedToConcatenated(dim=1)
        result = op.convert(
            {"test": interleaved_tensor},
            source_patterns=["test"],
            target_patterns=["test"],
        )
        concatenated = result["test"]

        E, two_I, H = concatenated.shape
        I = two_I // 2  # noqa: E741

        # First half should be gate (even rows from interleaved)
        assert torch.equal(concatenated[:, :I, :], interleaved_tensor[:, 0::2, :])
        # Second half should be up (odd rows from interleaved)
        assert torch.equal(concatenated[:, I:, :], interleaved_tensor[:, 1::2, :])

    def test_reverse_op_type(self):
        op = InterleavedToConcatenated(dim=1)
        assert isinstance(op.reverse_op, ConcatenatedToInterleaved)
        assert op.reverse_op.dim == 1


class TestRoundTrip:
    @pytest.fixture
    def concat_tensor(self):
        E, I, H = 4, 8, 16  # noqa: E741
        gate = torch.randn(E, I, H)
        up = torch.randn(E, I, H)
        return torch.cat([gate, up], dim=1)

    def test_interleave_then_deinterleave_is_identity(self, concat_tensor):
        fwd = ConcatenatedToInterleaved(dim=1)
        rev = InterleavedToConcatenated(dim=1)

        interleaved = fwd.convert(
            {"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"]
        )["k"]
        recovered = rev.convert(
            {"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
        )["k"]

        assert torch.equal(concat_tensor, recovered)

    def test_reverse_op_chain_is_identity(self, concat_tensor):
        """Verify that op.reverse_op produces an exact inverse."""
        op = ConcatenatedToInterleaved(dim=1)
        rev = op.reverse_op

        interleaved = op.convert(
            {"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"]
        )["k"]
        recovered = rev.convert(
            {"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
        )["k"]

        assert torch.equal(concat_tensor, recovered)

    def test_various_shapes(self):
        """Test with different expert counts and dimensions."""
        fwd = ConcatenatedToInterleaved(dim=1)
        rev = InterleavedToConcatenated(dim=1)

        for E, I, H in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]:  # noqa: E741
            concat = torch.randn(E, 2 * I, H)
            interleaved = fwd.convert(
                {"k": concat}, source_patterns=["k"], target_patterns=["k"]
            )["k"]
            recovered = rev.convert(
                {"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
            )["k"]
            assert torch.equal(concat, recovered), (
                f"Failed for shape ({E}, {2 * I}, {H})"
            )


class TestWeightConverterRegistration:
    def test_register_appends_interleave_op(self):
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping

        register_sonicmoe_weight_converter("qwen3_moe")

        modified = get_checkpoint_conversion_mapping("qwen3_moe")
        # Find the gate_up_proj converter
        gate_up_converter = None
        for conv in modified:
            if hasattr(conv, "operations") and any(
                "gate_up_proj" in pat for pat in conv.target_patterns
            ):
                gate_up_converter = conv
                break

        assert gate_up_converter is not None
        assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved)

    def test_double_registration_is_idempotent(self):
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping

        register_sonicmoe_weight_converter("qwen3_moe")
        register_sonicmoe_weight_converter("qwen3_moe")

        modified = get_checkpoint_conversion_mapping("qwen3_moe")
        for conv in modified:
            if hasattr(conv, "operations") and any(
                "gate_up_proj" in pat for pat in conv.target_patterns
            ):
                interleave_count = sum(
                    isinstance(op, ConcatenatedToInterleaved) for op in conv.operations
                )
                assert interleave_count == 1, (
                    f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}"
                )
                break

    def test_register_unsupported_model_type_warns(self):
        # A model type with no conversion mapping should warn but not raise
        register_sonicmoe_weight_converter("nonexistent_model_type_xyz")


def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
    """Create a mock qwen-style MoE block for routing tests."""
    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        num_experts=E,
        norm_topk_prob=True,
    )
    return SimpleNamespace(gate=gate), T, H, E, K


def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1):
    """Create a mock GLM5-style MoE block for routing tests."""
    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        e_score_correction_bias=torch.zeros(E),
    )
    moe_block = SimpleNamespace(
        gate=gate,
        top_k=K,
        n_routed_experts=E,
        n_group=n_group,
        topk_group=topk_group,
        norm_topk_prob=True,
        routed_scaling_factor=1.0,
    )
    return moe_block, T, H, E, K


def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4):
    """Create a mock minimax_m2-style MoE block for routing tests.

    minimax_m2 uses sigmoid->topk WITHOUT group selection:
    - e_score_correction_bias is on the moe_block (not on gate)
    - No n_group / topk_group attributes
    - Always normalizes (norm_topk_prob defaults to True)
    - No routed_scaling_factor (defaults to 1.0)
    """
    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
    )
    moe_block = SimpleNamespace(
        gate=gate,
        top_k=K,
        e_score_correction_bias=torch.zeros(E),
    )
    return moe_block, T, H, E, K


class TestSoftmaxTopkRouting:
    def test_output_shapes(self):
        moe_block, T, H, E, K = _make_qwen_moe_block()
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block)

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        moe_block, T, H, E, K = _make_qwen_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        moe_block, T, H, E, K = _make_qwen_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_topk_routing(hidden, moe_block)

        # Token indices must be sorted ascending (SonicMoE requirement)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        moe_block, T, H, E, K = _make_qwen_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_topk_routing(hidden, moe_block)

        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        moe_block, T, H, E, K = _make_qwen_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)


class TestSigmoidTopkRouting:
    def test_output_shapes(self):
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = sigmoid_topk_routing(hidden, moe_block)

        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_scores_are_nonnegative(self):
        """Sigmoid outputs are in [0, 1], so scores should be non-negative."""
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
        assert (scores >= 0).all()

    def test_scaling_factor_applied(self):
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        # Get scores with scaling_factor=1.0
        scores_1x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)

        # Get scores with scaling_factor=2.0
        moe_block.routed_scaling_factor = 2.0
        scores_2x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)

        assert torch.allclose(scores_2x, scores_1x * 2.0, atol=1e-5)

    def test_group_selection_restricts_experts(self):
|
||||
"""With n_group=4 and topk_group=1, only 1/4 of experts should be selectable."""
|
||||
moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1)
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
# Each token's experts should all fall within a single group (size E//n_group=4)
|
||||
expert_idx_2d = expert_idx.reshape(T, K)
|
||||
for t in range(T):
|
||||
experts = expert_idx_2d[t]
|
||||
groups = experts // (E // moe_block.n_group)
|
||||
# All selected experts should be from the same group
|
||||
assert (groups == groups[0]).all()
|
||||
|
||||
|
||||
class TestMiniMaxM2SigmoidRouting:
|
||||
"""Tests for minimax_m2 routing: sigmoid->topk without group selection."""
|
||||
|
||||
def test_output_shapes(self):
|
||||
"""Validates getattr defaults work: n_group=1, E from gate.weight.shape[0]."""
|
||||
moe_block, T, H, E, K = _make_minimax_m2_moe_block()
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
assert scores.shape == (T * K,)
|
||||
assert token_idx.shape == (T * K,)
|
||||
assert expert_idx.shape == (T * K,)
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_bias_on_block_not_gate(self):
|
||||
"""Verify that e_score_correction_bias on the block (not gate) is used."""
|
||||
T, H, E, K = 8, 16, 8, 2
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
top_k=K,
|
||||
)
|
||||
# Large positive bias on expert 0 should make it selected more often
|
||||
bias = torch.zeros(E)
|
||||
bias[0] = 100.0
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
e_score_correction_bias=bias,
|
||||
)
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
# Expert 0 should appear for every token due to the large bias
|
||||
expert_idx_2d = expert_idx.reshape(T, K)
|
||||
for t in range(T):
|
||||
assert 0 in expert_idx_2d[t]
|
||||
158 tests/integrations/test_sonicmoe_gradients.py (new file)
@@ -0,0 +1,158 @@
"""
Gradient correctness tests for SonicMoE routing functions (CPU-only).

Uses torch.autograd.gradcheck with float32 inputs to match the production
code path where routing happens in float32.
"""

import torch

from axolotl.integrations.kernels.sonicmoe.routing import (
    sigmoid_topk_routing,
    softmax_topk_routing,
)

_GC_EPS = 1e-3
_GC_ATOL = 1e-3
_GC_RTOL = 1e-3


def _make_softmax_moe_block(weight):
    gate = torch.nn.Module()
    gate.weight = weight
    gate.top_k = 2
    gate.norm_topk_prob = True

    moe_block = torch.nn.Module()
    moe_block.gate = gate
    return moe_block


def _make_sigmoid_moe_block(weight, bias):
    gate = torch.nn.Module()
    gate.weight = weight
    gate.e_score_correction_bias = bias

    moe_block = torch.nn.Module()
    moe_block.gate = gate
    moe_block.top_k = 2
    moe_block.n_routed_experts = weight.shape[0]
    moe_block.n_group = 1
    moe_block.norm_topk_prob = True
    moe_block.routed_scaling_factor = 1.0
    return moe_block


class TestSoftmaxTopkRoutingGradcheck:
    """Numerical gradient verification for softmax_topk_routing."""

    def test_gradcheck_wrt_gate_weight(self):
        T, H, E = 4, 8, 4

        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_hidden_states(self):
        T, H, E = 4, 8, 4

        weight = torch.randn(E, H, dtype=torch.float32)
        moe_block = _make_softmax_moe_block(weight)

        def fn(hidden):
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_router_logits(self):
        T, H, E = 4, 8, 4

        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            _, _, _, router_logits = softmax_topk_routing(hidden, moe_block)
            return router_logits

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_no_norm_variant(self):
        T, H, E = 4, 8, 4

        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            moe_block.gate.norm_topk_prob = False
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )


class TestSigmoidTopkRoutingGradcheck:
    """Numerical gradient verification for sigmoid_topk_routing."""

    def test_gradcheck_wrt_gate_weight(self):
        T, H, E = 4, 8, 4

        hidden = torch.randn(T, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_sigmoid_moe_block(weight, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_hidden_states(self):
        T, H, E = 4, 8, 4

        weight = torch.randn(E, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)
        moe_block = _make_sigmoid_moe_block(weight, bias)

        def fn(hidden):
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_bias(self):
        T, H, E = 4, 8, 4

        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32)

        def fn(bias):
            moe_block = _make_sigmoid_moe_block(weight, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)