From 6a8baf8fa7c76651edb8a509c505fa19ea07599b Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 6 Mar 2026 01:43:31 +0700
Subject: [PATCH] feat: add sonicmoe (#3411)

* feat: add sonicmoe
* feat: add torch compile for routing
* feat: add routing smoke test
* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe
* fix: disable mlp kernel for sonicmoe too
* feat: update to sonicmoe release
* chore: update import following new sonicmoe changes
* feat: update handling for blackwell
* feat: add sonicmoe e2e test
* fix: installation for updated sonicmoe
* fix: git commit
* fix: ignore py req and fix metadata
* fix: increase min hidden size to match sonicmoe kernel min
* fix: attempt properly interleave and handle unpatch mid-test
* chore: refactor teardown better
* chore: refactor to re-use rearrange
* fix: add idempotency guard
* fix: address comments on CI memory and interleave
* fix: tests grad, param doublewrapped
---
 src/axolotl/integrations/kernels/README.md    |  42 +-
 src/axolotl/integrations/kernels/args.py      |  19 +-
 src/axolotl/integrations/kernels/constants.py |  68 +++
 src/axolotl/integrations/kernels/plugin.py    | 107 +++--
 .../integrations/kernels/sonicmoe/__init__.py |   3 +
 .../integrations/kernels/sonicmoe/patch.py    | 213 +++++++++
 .../integrations/kernels/sonicmoe/routing.py  | 219 +++++++++
 .../kernels/sonicmoe/weight_converter.py      | 181 ++++++++
 tests/e2e/integrations/test_sonicmoe.py       | 288 ++++++++++++
 tests/integrations/test_scattermoe_lora.py    |  14 +-
 tests/integrations/test_sonicmoe.py           | 428 ++++++++++++++++++
 tests/integrations/test_sonicmoe_gradients.py | 158 +++++++
 12 files changed, 1698 insertions(+), 42 deletions(-)
 create mode 100644 src/axolotl/integrations/kernels/constants.py
 create mode 100644 src/axolotl/integrations/kernels/sonicmoe/__init__.py
 create mode 100644 src/axolotl/integrations/kernels/sonicmoe/patch.py
 create mode 100644 src/axolotl/integrations/kernels/sonicmoe/routing.py
 create mode 100644 src/axolotl/integrations/kernels/sonicmoe/weight_converter.py
 create mode 100644 tests/e2e/integrations/test_sonicmoe.py
 create mode 100644 tests/integrations/test_sonicmoe.py
 create mode 100644 tests/integrations/test_sonicmoe_gradients.py

diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md
index 237d653cf..7c40720af 100644
--- a/src/axolotl/integrations/kernels/README.md
+++ b/src/axolotl/integrations/kernels/README.md
@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
     }
 ```
 
-In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
+In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.
 
 ## Usage
 
@@ -21,23 +21,55 @@ plugins:
   - axolotl.integrations.kernels.KernelsPlugin
 
 use_kernels: true
+
+# Choose one (mutually exclusive):
 use_scattermoe: true
+# OR
+use_sonicmoe: true
 ```
 
-**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
+**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
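+
+For illustration, a minimal end-to-end SonicMoE fine-tuning config might look like the sketch below; the base model and dataset shown here are placeholders, not requirements:
+
+```yaml
+base_model: Qwen/Qwen3-30B-A3B  # placeholder: any MoE model type listed in constants.py
+
+plugins:
+  - axolotl.integrations.kernels.KernelsPlugin
+
+use_kernels: true
+use_sonicmoe: true
+
+datasets:
+  - path: tatsu-lab/alpaca  # placeholder dataset
+    type: alpaca
+```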
+
+### SonicMoE installation
+
+**Prerequisites:**
+- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
+- CUDA 12.9+ (13.0+ for B300)
+- PyTorch 2.7+ (2.9.1 recommended)
+- For B300: Triton 3.6.0
+
+```bash
+pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
+```
+
+See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.
+
+**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
 
 ## How It Works
 
 The `KernelsPlugin` runs before model loading and:
 
-1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
+### ScatterMoE
+1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
 2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
 
-This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
+### SonicMoE
+1. Resolves the model's MoE block class(es) from `constants.py`.
+2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
+3. Supports both softmax->topk and sigmoid->topk routing strategies.
+
+Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
+
+#### Supported Models
+
+See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
 
 ## Limitations
 
-ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
+ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).
+
+SonicMoE supports both softmax -> topk and sigmoid -> topk routing, covering a wider range of architectures.
diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py
index e8cf7208a..d9b261e72 100644
--- a/src/axolotl/integrations/kernels/args.py
+++ b/src/axolotl/integrations/kernels/args.py
@@ -6,7 +6,18 @@ LOG = get_logger(__name__)
 
 class KernelsArgs(BaseModel):
-    use_scattermoe: bool | None = True
+    use_scattermoe: bool | None = None
+    use_sonicmoe: bool | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_mutually_exclusive(cls, data):
+        if data.get("use_scattermoe") and data.get("use_sonicmoe"):
+            raise ValueError(
+                "Cannot use both ScatterMoE and SonicMoE simultaneously. "
+                "Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
+ ) + return data @model_validator(mode="before") @classmethod @@ -36,11 +47,11 @@ class KernelsArgs(BaseModel): @model_validator(mode="before") @classmethod - def disable_mlp_kernel_scattermoe(cls, data): - if data.get("use_scattermoe") is True: + def disable_mlp_kernel(cls, data): + if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True: if data.get("lora_mlp_kernel") is True: LOG.warning( - "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues." + "Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues." ) data["lora_mlp_kernel"] = False data["mlp_kernel"] = False diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py new file mode 100644 index 000000000..529ed4ad6 --- /dev/null +++ b/src/axolotl/integrations/kernels/constants.py @@ -0,0 +1,68 @@ +""" +Supported MoE block mappings for kernel integrations. + +Maps model_type to the SparseMoeBlock class name(s) in transformers. +Used by both ScatterMoE and SonicMoE kernel paths. + +Values can be a single class name (str) or a list of class names for models +with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker). +""" + +import importlib + +SPARSE_MOE_BLOCK = { + # softmax -> topk routing + "qwen2_moe": "Qwen2MoeSparseMoeBlock", + "qwen3_moe": "Qwen3MoeSparseMoeBlock", + "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock", + "qwen3_next": "Qwen3NextSparseMoeBlock", + "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock", + # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate) + "qwen3_omni_moe": [ + "Qwen3OmniMoeThinkerTextSparseMoeBlock", + "Qwen3OmniMoeTalkerTextSparseMoeBlock", + ], + "olmoe": "OlmoeSparseMoeBlock", + "mixtral": "MixtralSparseMoeBlock", + "minimax": "MiniMaxSparseMoeBlock", + # sigmoid -> topk routing (with group-based expert selection) + "glm_moe_dsa": "GlmMoeDsaMoE", + "deepseek_v3": "DeepseekV3MoE", + "glm4_moe": "Glm4MoeMoE", + "glm4_moe_lite": "Glm4MoeLiteMoE", + "glm4v_moe": "Glm4vMoeTextMoE", + # sigmoid -> topk routing (no group selection) + "minimax_m2": "MiniMaxM2SparseMoeBlock", + # Models below need custom routing (not yet implemented): + # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk + # "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group) + # "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing + # "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases +} + + +def resolve_moe_block_classes(model_type: str): + """Resolve all MoE block classes from transformers for the given model type. + + Returns a list of classes (one for most models, multiple for models with + distinct MoE block types like qwen3_omni_moe). + """ + entry = SPARSE_MOE_BLOCK.get(model_type) + if entry is None: + raise ValueError( + f"Unsupported MoE model type '{model_type}'. 
" + f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}" + ) + + cls_names = entry if isinstance(entry, list) else [entry] + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + module = importlib.import_module(module_path) + + classes = [] + for cls_name in cls_names: + moe_cls = getattr(module, cls_name, None) + if moe_cls is None: + raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'") + classes.append(moe_cls) + + return classes diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index 56d0448d5..ad14dd148 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,14 +1,59 @@ +import importlib +import os from pathlib import Path -from kernels import ( - LocalLayerRepository, - Mode, - register_kernel_mapping, - replace_kernel_forward_from_hub, -) +import torch from axolotl.integrations.base import BasePlugin -from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def _check_sonicmoe_gpu_compat(): + """Validate GPU compute capability for SonicMoE and configure env. + + Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103). + B300 (sm_103) additionally requires Triton 3.6.0. + """ + if not torch.cuda.is_available(): + return + + cc = torch.cuda.get_device_capability() + + if cc < (9, 0): + raise RuntimeError( + f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, " + f"but detected sm_{cc[0]}{cc[1]}." + ) + + if cc > (10, 3): + raise RuntimeError( + f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. " + f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)." + ) + + # Blackwell (sm_100+): enable QuACK GEMM kernels + if cc >= (10, 0): + os.environ.setdefault("USE_QUACK_GEMM", "1") + LOG.info( + f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1" + ) + + # B300 (sm_103): requires Triton 3.6.0 + if cc == (10, 3): + triton_spec = importlib.util.find_spec("triton") + if triton_spec is None: + raise RuntimeError( + "B300 (sm_103) requires Triton 3.6.0, but Triton is not installed." + ) + import triton + + triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2]) + if triton_version != (3, 6): + raise RuntimeError( + f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}." + ) class KernelsPlugin(BasePlugin): @@ -19,8 +64,32 @@ class KernelsPlugin(BasePlugin): if cfg.use_scattermoe: self._register_kernels() self._kernelize_model(cfg.model_config_type) + elif cfg.use_sonicmoe: + if not importlib.util.find_spec("sonicmoe"): + raise RuntimeError( + "SonicMoE is not installed. 
See installation instructions at " + "https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation" + ) + + _check_sonicmoe_gpu_compat() + + from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe + + LOG.info( + f"Applying SonicMoE patches for model type: {cfg.model_config_type}" + ) + patch_sonicmoe( + cfg.model_config_type, + torch_compile=bool(getattr(cfg, "torch_compile", False)), + ) def _register_kernels(self): + from kernels import ( + LocalLayerRepository, + Mode, + register_kernel_mapping, + ) + plugin_root = Path(__file__).parent register_kernel_mapping( { @@ -42,25 +111,11 @@ class KernelsPlugin(BasePlugin): ) def _kernelize_model(self, model_type: str): - if model_type == "olmoe": - from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock + from kernels import replace_kernel_forward_from_hub + from axolotl.integrations.kernels.constants import resolve_moe_block_classes + + for model_moe_cls in resolve_moe_block_classes(model_type): replace_kernel_forward_from_hub( - OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts" + model_moe_cls, "HFScatterMoEParallelExperts" ) - else: - try: - model_moe_cls = get_model_moe_block(model_type) - replace_kernel_forward_from_hub( - model_moe_cls, "HFScatterMoEParallelExperts" - ) - except Exception as err: - raise ValueError(f"Unsupported model type: {model_type}") from err - - -def get_model_moe_block(model_type: str): - module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) - module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"]) - model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock") - return model_cls diff --git a/src/axolotl/integrations/kernels/sonicmoe/__init__.py b/src/axolotl/integrations/kernels/sonicmoe/__init__.py new file mode 100644 index 000000000..d1f5e5f60 --- /dev/null +++ b/src/axolotl/integrations/kernels/sonicmoe/__init__.py @@ -0,0 +1,3 @@ +from .patch import patch_sonicmoe + +__all__ = ["patch_sonicmoe"] diff --git a/src/axolotl/integrations/kernels/sonicmoe/patch.py b/src/axolotl/integrations/kernels/sonicmoe/patch.py new file mode 100644 index 000000000..a3b96f12a --- /dev/null +++ b/src/axolotl/integrations/kernels/sonicmoe/patch.py @@ -0,0 +1,213 @@ +""" +SonicMoE patching for SparseMoeBlock forward pass. + +Monkeypatches the SparseMoeBlock class for a given model type to use +SonicMoE's optimized kernels. Two forward paths are supported: + +1. **General routing path** (routing_fn is not None): + Uses a custom routing function + ``moe_general_routing_inputs``. + Suitable for models with non-standard routing (softmax->topk, sigmoid->topk). + +2. **Fused topk->softmax path** (routing_fn is None): + Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation. + Suitable for models with simple topk->softmax routing. + +Weight format conversion (interleave/deinterleave) is handled by the +WeightConverter system, so the forward assumes weights are already in +interleaved format. + +Shared experts are handled generically: if the block has a ``shared_expert`` +or ``shared_experts`` attribute, its output is computed alongside the routed +experts and added to the final output. An optional ``shared_expert_gate`` +applies sigmoid gating to the shared expert contribution. 
+""" + +import torch +import torch.nn.functional as F + +from axolotl.integrations.kernels.constants import resolve_moe_block_classes +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def patch_sonicmoe(model_type: str, torch_compile: bool = False): + """Main entry point: patch SparseMoeBlock for SonicMoE support. + + Args: + model_type: The HuggingFace model type (e.g. "qwen3_moe"). + torch_compile: If True, wrap routing functions with torch.compile + for kernel fusion (fuses softmax+topk+renorm into fewer launches). + """ + from .routing import get_model_moe_config + from .weight_converter import register_sonicmoe_weight_converter + + routing_fn, activation, router_attr = get_model_moe_config(model_type) + + if torch_compile and routing_fn is not None: + routing_fn = _try_compile_routing(routing_fn) + + for moe_cls in resolve_moe_block_classes(model_type): + _patch_forward(moe_cls, routing_fn, activation, router_attr) + register_sonicmoe_weight_converter(model_type) + + +def _try_compile_routing(routing_fn): + """Attempt to torch.compile the routing function, fall back to eager on failure.""" + try: + compiled_fn = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False) + LOG.info(f"torch.compile enabled for routing function: {routing_fn.__name__}") + return compiled_fn + except Exception as exc: # pylint: disable=broad-except + LOG.warning( + f"torch.compile failed for routing function {routing_fn.__name__}, " + f"falling back to eager: {exc}" + ) + return routing_fn + + +def _patch_forward(moe_cls, routing_fn, activation, router_attr): + """Monkeypatch the SparseMoeBlock class with a SonicMoE forward. + + The patched forward handles shared experts generically: if + ``self.shared_expert`` or ``self.shared_experts`` exists, it is computed + and added to the routed output. If ``self.shared_expert_gate`` also exists, + it applies sigmoid gating to the shared expert contribution (as in qwen2_moe). + + Args: + moe_cls: The SparseMoeBlock class to patch. + routing_fn: Routing function (e.g. softmax_topk_routing), or None + for the fused moe_TC_softmax_topk_layer path. + activation: SonicMoE ActivationType enum value. + router_attr: Name of the router module attribute on the MoE block. 
+ """ + if hasattr(moe_cls, "_original_forward"): + LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping") + return + + original_forward = moe_cls.forward + + if routing_fn is not None: + _make_general_forward(moe_cls, routing_fn, activation) + else: + _make_fused_forward(moe_cls, activation, router_attr) + + moe_cls._original_forward = original_forward + LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation") + + +def _make_general_forward(moe_cls, routing_fn, activation): + """Create forward using routing_fn + moe_general_routing_inputs.""" + + def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + from sonicmoe import moe_general_routing_inputs + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Shared expert (computed early, matching original model ordering) + shared_expert_output = _compute_shared_expert(self, hidden_states_flat) + + # Routing + router_scores, token_indices, expert_indices, _router_logits = routing_fn( + hidden_states_flat, self + ) + + # Permute weights to SonicMoE layout: + # gate_up: [E, 2*I, H] -> [2*I, H, E] + # down: [E, H, I] -> [H, I, E] + gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) + down_weight = self.experts.down_proj.permute(1, 2, 0) + E = gate_up_weight.shape[-1] + + output, _ = moe_general_routing_inputs( + hidden_states_flat, + router_scores, + token_indices, + expert_indices, + gate_up_weight, + None, # b1 (no gate/up bias) + down_weight, + None, # b2 (no down bias) + E, + torch.cuda.current_stream().cuda_stream, + activation, + False, # is_inference_mode + ) + + # Add shared expert contribution if present + if shared_expert_output is not None: + if hasattr(self, "shared_expert_gate"): + shared_expert_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states_flat)) + * shared_expert_output + ) + output = output + shared_expert_output + + return output.view(batch_size, sequence_length, hidden_dim) + + moe_cls.forward = sonicmoe_forward + + +def _make_fused_forward(moe_cls, activation, router_attr): + """Create forward using moe_TC_softmax_topk_layer (topk -> softmax).""" + + def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + from sonicmoe import moe_TC_softmax_topk_layer + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Shared expert (computed early, matching original model ordering) + shared_expert_output = _compute_shared_expert(self, hidden_states_flat) + + router = getattr(self, router_attr) + + # Permute weights to SonicMoE layout: + # gate_up: [E, 2*I, H] -> [2*I, H, E] + # down: [E, H, I] -> [H, I, E] + gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) + down_weight = self.experts.down_proj.permute(1, 2, 0) + + output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer( + hidden_states_flat, + router.weight, + gate_up_weight, + None, # b1 (no gate/up bias) + down_weight, + None, # b2 (no down bias) + router.top_k, + torch.cuda.current_stream().cuda_stream, + activation, + False, # is_inference_mode + ) + + # Add shared expert contribution if present + if shared_expert_output is not None: + if hasattr(self, "shared_expert_gate"): + shared_expert_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states_flat)) + * shared_expert_output + ) + output = output + shared_expert_output + + return output.view(batch_size, sequence_length, hidden_dim) + + moe_cls.forward 
= sonicmoe_fused_forward + + +def _compute_shared_expert(moe_block, hidden_states_flat): + """Compute shared expert output if the block has one. + + Handles singular (qwen2_moe: ``shared_expert``), plural + (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP + (hunyuan_v1_moe: ``shared_mlp``) attribute names. + """ + shared_expert = ( + getattr(moe_block, "shared_expert", None) + or getattr(moe_block, "shared_experts", None) + or getattr(moe_block, "shared_mlp", None) + ) + if shared_expert is not None: + return shared_expert(hidden_states_flat) + return None diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/sonicmoe/routing.py new file mode 100644 index 000000000..3f93c1596 --- /dev/null +++ b/src/axolotl/integrations/kernels/sonicmoe/routing.py @@ -0,0 +1,219 @@ +""" +Routing functions for SonicMoE integration. + +Different MoE architectures use different routing strategies: +- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization) +- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) +- glm_moe_dsa: sigmoid -> topk (with group-based expert selection) + +Each model type maps to a (routing_fn, activation_type, router_attr) triple. +When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. +""" + +import torch +import torch.nn.functional as F + + +def get_model_moe_config(model_type: str): + """Returns (routing_fn, activation, router_attr) for a given model type. + + Args: + model_type: HuggingFace model type string. + + Returns: + routing_fn: Callable or None. None signals the fused + moe_TC_softmax_topk_layer path (topk -> softmax models). + activation: SonicMoE ActivationType enum value. + router_attr: Name of the router module attribute on the MoE block + (e.g. "gate" or "router"). + + The activation type cannot be derived from config.hidden_act because + e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU + (act_fn(gate) * up pattern). So we specify it per model type. + """ + from sonicmoe.enums import ActivationType + + if model_type in ( + "qwen2_moe", + "qwen3_moe", + "qwen3_5_moe", + "qwen3_next", + "qwen3_vl_moe", + "qwen3_omni_moe", + "olmoe", + "mixtral", + "minimax", + ): + return softmax_topk_routing, ActivationType.SWIGLU, "gate" + elif model_type in ( + "glm_moe_dsa", + "deepseek_v3", + "glm4_moe", + "glm4_moe_lite", + "glm4v_moe", + "minimax_m2", + ): + return sigmoid_topk_routing, ActivationType.SWIGLU, "gate" + # elif model_type in ("ernie4_5_moe",): + # # Softmax→topk with e_score_correction_bias applied between softmax and topk. + # return ..., ActivationType.SWIGLU, "gate" + # elif model_type in ("deepseek_v2",): + # # Softmax→topk with group_limited_greedy. Different attr names: num_group + # # (not n_group), gate is nn.Linear (not a router class). + # return ..., ActivationType.SWIGLU, "gate" + # elif model_type in ("hunyuan_v1_moe",): + # # Softmax→topk but gate structure differs: gate.wg (not gate.weight), + # # top_k on block not gate, creates scatter routing matrix. + # return ..., ActivationType.SWIGLU, "gate" + # Fused topk -> softmax path (routing_fn=None): + # elif model_type in ("gpt_oss",): + # # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer + # # ignores (it only takes router_w, not bias). Also has transposed + # # weight layout [E, H, 2*I] and custom GLU activation. 
+ # return None, ActivationType.SWIGLU, "router" + else: + raise ValueError(f"SonicMoE: unsupported model type '{model_type}'") + + +def softmax_topk_routing( + hidden_states: torch.Tensor, moe_block +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm. + + Args: + hidden_states: [T, H] flattened token representations + moe_block: MoE block module (accesses moe_block.gate.*) + + Returns: + router_scores: [T*K] flattened scores (float32) + token_indices: [T*K] which token each entry belongs to (int32), sorted ascending + expert_indices: [T*K] which expert (int32) + router_logits: [T, E] original logits for aux loss + """ + gate = moe_block.gate + T, H = hidden_states.shape + K = gate.top_k + + # Compute router logits and softmax over all experts + router_logits = F.linear(hidden_states, gate.weight) # [T, E] + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + + # Select top-k experts per token + top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each + + # Renormalize if configured (default True for models without the attribute, + # e.g. Mixtral/MiniMax which always normalize) + if getattr(gate, "norm_topk_prob", True): + top_values = top_values / top_values.sum(dim=-1, keepdim=True) + + # no-op: matches transformers which casts to softmax output dtype (float32). + # top_values = top_values.to(router_probs.dtype) + + # Flatten for moe_general_routing_inputs. + # Token indices are naturally sorted ascending from the [T, K] layout: + # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE. + # Expert sorting is handled internally by general_routing_router_metadata. + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = top_values.reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits + + +def sigmoid_topk_routing( + hidden_states: torch.Tensor, moe_block +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Sigmoid-based routing: sigmoid -> optional group selection -> topk. + + Supports two variants: + - **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1, + bias on gate, group-based masking before topk. + - **No group selection** (minimax_m2): n_group == 1 (or absent), + bias on moe_block, straight topk from all experts. + + Final routing weights come from the original sigmoid scores (not + bias-corrected), with optional renormalization and scaling. 
+ + Args: + hidden_states: [T, H] flattened token representations + moe_block: MoE block module (accesses moe_block.gate.* and + optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob, + .routed_scaling_factor, .n_routed_experts) + + Returns: + router_scores: [T*K] flattened scores (float32) + token_indices: [T*K] which token each entry belongs to (int32), sorted ascending + expert_indices: [T*K] which expert (int32) + router_logits: [T, E] original logits for aux loss + """ + gate = moe_block.gate + T, H = hidden_states.shape + K = moe_block.top_k + E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0]) + n_group = getattr(moe_block, "n_group", 1) + + # Compute router logits and sigmoid probabilities + router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E] + router_probs = router_logits.sigmoid() # [T, E] + + # Bias-corrected scores for expert selection (not used for final weights). + # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block. + e_score_correction_bias = getattr(gate, "e_score_correction_bias", None) + if e_score_correction_bias is None: + e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None) + if e_score_correction_bias is None: + raise AttributeError( + f"sigmoid_topk_routing requires e_score_correction_bias on " + f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it" + ) + scores_for_choice = router_probs + e_score_correction_bias + + # Group-based selection: pick top groups, mask the rest (skip when n_group == 1) + if n_group > 1: + group_scores = ( + scores_for_choice.view(-1, n_group, E // n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [T, n_group] + group_idx = torch.topk( + group_scores, k=moe_block.topk_group, dim=-1, sorted=False + )[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + + # Final topk from (possibly masked) scores + topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] + + # Gather weights from original sigmoid scores (not bias-corrected) + topk_weights = router_probs.gather(1, topk_indices) + + # Optional renormalization + scaling + norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) + if norm_topk_prob: + topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) + routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) + topk_weights = topk_weights * routed_scaling_factor + + # Flatten for moe_general_routing_inputs. + # Token indices are naturally sorted ascending from the [T, K] layout. + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits diff --git a/src/axolotl/integrations/kernels/sonicmoe/weight_converter.py b/src/axolotl/integrations/kernels/sonicmoe/weight_converter.py new file mode 100644 index 000000000..172864ac6 --- /dev/null +++ b/src/axolotl/integrations/kernels/sonicmoe/weight_converter.py @@ -0,0 +1,181 @@ +""" +Custom WeightConverter operations for SonicMoE weight format conversion. 
+ +SonicMoE requires gate_up_proj weights in interleaved format: +- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up +- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...] + +These ConversionOps integrate with transformers' WeightConverter system so that +weights are transparently converted during loading and reverted during saving. +""" + +from typing import Any + +import torch +from einops import rearrange +from transformers.core_model_loading import ConversionOps + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor: + """[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension.""" + return rearrange(tensor, "... (two out) h -> ... (out two) h", two=2) + + +def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor: + """[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension.""" + return rearrange(tensor, "... (out two) h -> ... (two out) h", two=2) + + +class ConcatenatedToInterleaved(ConversionOps): + """Convert concatenated gate/up projections to interleaved format. + + Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H] + Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...] + + This operation is applied along ``dim`` (default 1, the 2*I dimension). + """ + + def __init__(self, dim: int = 1): + self.dim = dim + + @torch.no_grad() + def convert( + self, + input_dict: dict[str, Any], + source_patterns: list[str], + target_patterns: list[str], + **kwargs, + ) -> dict[str, torch.Tensor]: + target_pattern = self._get_target_pattern( + input_dict, source_patterns, target_patterns + ) + tensors = next(iter(input_dict.values())) + tensor = tensors[0] if isinstance(tensors, list) else tensors + + interleaved = interleave_gate_up(tensor) + + return {target_pattern: interleaved} + + def _get_target_pattern( + self, + input_dict: dict[str, Any], + source_patterns: list[str], + target_patterns: list[str], + ) -> str: + # Follow the same logic as Transpose.get_target_pattern + if len(input_dict) != 1: + raise ValueError("Undefined Operation encountered!") + if len(target_patterns) > 1: + if len(source_patterns) == 1: + return source_patterns[0] + raise ValueError("Undefined Operation encountered!") + return target_patterns[0] + + @property + def reverse_op(self) -> ConversionOps: + return InterleavedToConcatenated(self.dim) + + +class InterleavedToConcatenated(ConversionOps): + """Convert interleaved gate/up projections back to concatenated format. + + Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...] + Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H] + + This is the reverse of ``ConcatenatedToInterleaved``. 
+ """ + + def __init__(self, dim: int = 1): + self.dim = dim + + @torch.no_grad() + def convert( + self, + input_dict: dict[str, Any], + source_patterns: list[str], + target_patterns: list[str], + **kwargs, + ) -> dict[str, torch.Tensor]: + target_pattern = self._get_target_pattern( + input_dict, source_patterns, target_patterns + ) + tensors = next(iter(input_dict.values())) + tensor = tensors[0] if isinstance(tensors, list) else tensors + + concatenated = deinterleave_gate_up(tensor) + + return {target_pattern: concatenated} + + def _get_target_pattern( + self, + input_dict: dict[str, Any], + source_patterns: list[str], + target_patterns: list[str], + ) -> str: + if len(input_dict) != 1: + raise ValueError("Undefined Operation encountered!") + if len(target_patterns) > 1: + if len(source_patterns) == 1: + return source_patterns[0] + raise ValueError("Undefined Operation encountered!") + return target_patterns[0] + + @property + def reverse_op(self) -> ConversionOps: + return ConcatenatedToInterleaved(self.dim) + + +def register_sonicmoe_weight_converter(model_type: str): + """Override the conversion mapping to add interleave step for gate_up_proj. + + Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj + converter chain. For example, qwen3_moe's chain becomes: + MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1) + + The reverse is auto-generated for saving: + InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0) + """ + from transformers.conversion_mapping import ( + get_checkpoint_conversion_mapping, + register_checkpoint_conversion_mapping, + ) + + existing = get_checkpoint_conversion_mapping(model_type) + if existing is None: + LOG.warning( + f"No conversion mapping found for model type '{model_type}'. " + "SonicMoE weight interleaving will not be applied during checkpoint loading." + ) + return + + # Find the gate_up_proj converter and append ConcatenatedToInterleaved + patched = False + for converter in existing: + if hasattr(converter, "operations") and any( + "gate_up_proj" in pat for pat in converter.target_patterns + ): + # Guard against double registration (e.g. plugin reloaded) + if any( + isinstance(op, ConcatenatedToInterleaved) for op in converter.operations + ): + LOG.info( + f"SonicMoE weight converter already registered for '{model_type}'" + ) + return + converter.operations.append(ConcatenatedToInterleaved(dim=1)) + patched = True + break + + if not patched: + LOG.warning( + f"Could not find gate_up_proj converter for model type '{model_type}'. " + "SonicMoE weight interleaving will not be applied during checkpoint loading." + ) + return + + register_checkpoint_conversion_mapping(model_type, existing, overwrite=True) + LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'") diff --git a/tests/e2e/integrations/test_sonicmoe.py b/tests/e2e/integrations/test_sonicmoe.py new file mode 100644 index 000000000..2152e94c7 --- /dev/null +++ b/tests/e2e/integrations/test_sonicmoe.py @@ -0,0 +1,288 @@ +""" +End-to-end gradient and convergence tests for SonicMoE integration. 
+ +Requires: + - H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90) + - sonicmoe package installed + - transformers with Qwen3MoE support + +Usage: + pytest tests/e2e/integrations/test_sonicmoe.py -v -s +""" + +import importlib.util +import math + +import pytest +import torch + +_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None +_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0) + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"), + pytest.mark.skipif( + not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)" + ), + pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"), +] + + +def _create_tiny_qwen3_config(): + """Create a minimal Qwen3MoE config for fast testing.""" + from transformers import AutoConfig + + config = AutoConfig.for_model("qwen3_moe") + config.hidden_size = 512 + config.intermediate_size = 1024 + config.moe_intermediate_size = 64 + config.num_attention_heads = 16 + config.num_key_value_heads = 2 + config.head_dim = 32 + config.num_hidden_layers = 2 + config.num_experts = 8 + config.num_experts_per_tok = 2 + config.vocab_size = 1000 + config.max_position_embeddings = 128 + config.norm_topk_prob = True + config.torch_dtype = torch.bfloat16 + return config + + +def _interleave_gate_up_weights(model): + """Interleave all gate_up_proj parameters in-place for SonicMoE.""" + from axolotl.integrations.kernels.sonicmoe.weight_converter import ( + interleave_gate_up, + ) + + with torch.no_grad(): + for name, param in model.named_parameters(): + if "gate_up_proj" in name: + param.copy_(interleave_gate_up(param)) + + +def _unpatch_sonicmoe(): + """Restore original forward on the MoE block class if it was patched.""" + from axolotl.integrations.kernels.constants import resolve_moe_block_classes + + for moe_cls in resolve_moe_block_classes("qwen3_moe"): + if hasattr(moe_cls, "_original_forward"): + moe_cls.forward = moe_cls._original_forward + del moe_cls._original_forward + + +class TestSonicMoEForwardCorrectness: + """Verify SonicMoE-patched model produces same output as original.""" + + def teardown_method(self): + _unpatch_sonicmoe() + + def test_forward_output_matches(self): + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") + + # Original model + model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + + with torch.no_grad(): + out_orig = model_orig(input_ids) + + # Patched model (same weights, interleaved for SonicMoE) + model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + model_patched.load_state_dict(model_orig.state_dict()) + + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model_patched) + + with torch.no_grad(): + out_patched = model_patched(input_ids) + + max_diff = (out_orig.logits - out_patched.logits).abs().max().item() + assert torch.allclose( + out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1 + ), f"Output mismatch: max diff={max_diff:.6f}" + + +class TestSonicMoEGradientCorrectness: + """Compare gradients between original HuggingFace and SonicMoE-patched forward.""" + + def teardown_method(self): + _unpatch_sonicmoe() + + def test_gradients_match(self): + """Verify all parameter gradients match between original and patched.""" + from transformers import AutoModelForCausalLM + + 
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + from axolotl.integrations.kernels.sonicmoe.weight_converter import ( + deinterleave_gate_up, + ) + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") + + # ---------- Original model ---------- + model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + out_orig = model_orig(input_ids, labels=input_ids) + out_orig.loss.backward() + grads_orig = { + n: p.grad.float().clone() + for n, p in model_orig.named_parameters() + if p.grad is not None + } + loss_orig = out_orig.loss.item() + + # ---------- SonicMoE-patched model (same weights, interleaved) ---------- + model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + model_patched.load_state_dict(model_orig.state_dict()) + + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model_patched) + + out_patched = model_patched(input_ids, labels=input_ids) + out_patched.loss.backward() + grads_patched = {} + for n, p in model_patched.named_parameters(): + if p.grad is None: + continue + g = p.grad.float().clone() + # gate_up_proj grads are in interleaved layout, de-interleave to match orig + if "gate_up_proj" in n: + g = deinterleave_gate_up(g) + grads_patched[n] = g + loss_patched = out_patched.loss.item() + + # ---------- Compare ---------- + assert abs(loss_orig - loss_patched) < 0.5, ( + f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}" + ) + + # All parameters with gradients in original should have them in patched + missing = set(grads_orig.keys()) - set(grads_patched.keys()) + assert not missing, f"Missing gradients in patched model: {missing}" + + # Compare gradient values + # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge, + # so use generous tolerance: flag only if both rel >10% AND abs >1e-2 + mismatches = [] + for name in grads_orig: + if name not in grads_patched: + continue + g_orig = grads_orig[name] + g_patched = grads_patched[name] + max_diff = (g_orig - g_patched).abs().max().item() + rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8) + + if rel_diff > 0.1 and max_diff > 1e-2: + mismatches.append( + f" {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}" + ) + + assert not mismatches, ( + "Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\n" + + "\n".join(mismatches) + ) + + def test_router_weights_receive_gradients(self): + """Verify that router (gate) weights get non-zero gradients.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + + out = model(input_ids, labels=input_ids) + out.loss.backward() + + gate_grads_found = False + for name, param in model.named_parameters(): + if "gate" in name and "weight" in name: + gate_grads_found = True + assert param.grad is not None, f"No gradient for router: {name}" + assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}" + + assert gate_grads_found, "No gate.weight parameters found in model" + + +class TestSonicMoETrainingConvergence: + """Verify loss decreases during training with SonicMoE.""" + + def teardown_method(self): + _unpatch_sonicmoe() + + def test_loss_decreases(self): + """Run 30 training steps, verify loss 
decreases and no NaN/Inf.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + losses = [] + + for step in range(30): + out = model(input_ids, labels=input_ids) + loss = out.loss + assert not math.isnan(loss.item()), f"NaN loss at step {step}" + assert not math.isinf(loss.item()), f"Inf loss at step {step}" + losses.append(loss.item()) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + assert losses[-1] < losses[0], ( + f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}" + ) + + def test_expert_weights_update(self): + """Verify expert weights change during training (not frozen).""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + + # Snapshot expert weights before training + expert_weights_before = {} + for name, param in model.named_parameters(): + if "experts" in name: + expert_weights_before[name] = param.data.clone() + + assert expert_weights_before, "No expert parameters found" + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + for _ in range(5): + out = model(input_ids, labels=input_ids) + out.loss.backward() + optimizer.step() + optimizer.zero_grad() + + # Check that expert weights changed + changed = 0 + for name, param in model.named_parameters(): + if name in expert_weights_before: + if not torch.equal(param.data, expert_weights_before[name]): + changed += 1 + + assert changed > 0, "No expert weights changed after 5 training steps" diff --git a/tests/integrations/test_scattermoe_lora.py b/tests/integrations/test_scattermoe_lora.py index 859119c81..d498c8010 100644 --- a/tests/integrations/test_scattermoe_lora.py +++ b/tests/integrations/test_scattermoe_lora.py @@ -6,7 +6,7 @@ Unit tests for scattermoe-lora code-review fixes. Tests cover: -- KernelsArgs validator: disable_mlp_kernel_scattermoe +- KernelsArgs validator: disable_mlp_kernel - CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward - ParallelExperts: scaling=0.0 not treated as falsy - single2scatter: non-aligned K/N dimensions @@ -20,12 +20,12 @@ import pytest import torch # ============================================================================ -# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator +# 1. KernelsArgs: disable_mlp_kernel validator # ============================================================================ class TestKernelsArgsValidator: - """Test that disable_mlp_kernel_scattermoe sets both flags correctly. + """Test that disable_mlp_kernel sets both flags correctly. These tests call the validator classmethod directly on raw dicts, since lora_mlp_kernel / mlp_kernel are not declared model fields. 
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator: "use_scattermoe": True, "lora_mlp_kernel": True, } - result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + result = KernelsArgs.disable_mlp_kernel(data) assert result["lora_mlp_kernel"] is False assert result["mlp_kernel"] is False @@ -52,7 +52,7 @@ class TestKernelsArgsValidator: "use_kernels": True, "use_scattermoe": True, } - result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + result = KernelsArgs.disable_mlp_kernel(data) assert result["mlp_kernel"] is False # lora_mlp_kernel was not in data, should not be added assert "lora_mlp_kernel" not in result @@ -66,7 +66,7 @@ class TestKernelsArgsValidator: "use_scattermoe": True, "lora_mlp_kernel": False, } - result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + result = KernelsArgs.disable_mlp_kernel(data) assert result["lora_mlp_kernel"] is False def test_no_change_when_scattermoe_disabled(self): @@ -78,7 +78,7 @@ class TestKernelsArgsValidator: "use_scattermoe": False, "lora_mlp_kernel": True, } - result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + result = KernelsArgs.disable_mlp_kernel(data) assert result["lora_mlp_kernel"] is True diff --git a/tests/integrations/test_sonicmoe.py b/tests/integrations/test_sonicmoe.py new file mode 100644 index 000000000..e6294f564 --- /dev/null +++ b/tests/integrations/test_sonicmoe.py @@ -0,0 +1,428 @@ +"""Unit tests for the SonicMoE integration.""" + +from types import SimpleNamespace + +import pytest +import torch + +from axolotl.integrations.kernels.args import KernelsArgs +from axolotl.integrations.kernels.sonicmoe.routing import ( + sigmoid_topk_routing, + softmax_topk_routing, +) +from axolotl.integrations.kernels.sonicmoe.weight_converter import ( + ConcatenatedToInterleaved, + InterleavedToConcatenated, + register_sonicmoe_weight_converter, +) + + +class TestKernelsArgs: + def test_mutual_exclusivity_raises(self): + with pytest.raises(ValueError, match="Cannot use both"): + KernelsArgs.model_validate({"use_scattermoe": True, "use_sonicmoe": True}) + + def test_sonicmoe_only(self): + result = KernelsArgs.model_validate({"use_sonicmoe": True}) + assert result.use_sonicmoe is True + assert result.use_scattermoe is None + + def test_scattermoe_only(self): + result = KernelsArgs.model_validate({"use_scattermoe": True}) + assert result.use_scattermoe is True + assert result.use_sonicmoe is None + + def test_neither_set(self): + result = KernelsArgs.model_validate({}) + assert result.use_scattermoe is None + assert result.use_sonicmoe is None + + def test_disables_mlp_kernel_when_sonicmoe(self): + data = {"use_sonicmoe": True, "lora_mlp_kernel": True} + result = KernelsArgs.disable_mlp_kernel(data) + assert result["lora_mlp_kernel"] is False + assert result["mlp_kernel"] is False + + +class TestConcatenatedToInterleaved: + @pytest.fixture + def sample_tensor(self): + """Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values.""" + E, I, H = 2, 2, 3 # noqa: E741 + gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H) + up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H) + return torch.cat([gate, up], dim=1) + + def test_interleave_rows_alternate(self, sample_tensor): + op = ConcatenatedToInterleaved(dim=1) + result = op.convert( + {"test": sample_tensor}, + source_patterns=["test"], + target_patterns=["test"], + ) + interleaved = result["test"] + + # For expert 0: even rows should be gate, odd rows should be up + E, two_I, H = sample_tensor.shape + I = two_I // 2 # noqa: 
E741 + gate_orig = sample_tensor[:, :I, :] + up_orig = sample_tensor[:, I:, :] + + assert torch.equal(interleaved[:, 0::2, :], gate_orig) + assert torch.equal(interleaved[:, 1::2, :], up_orig) + + def test_interleave_handles_list_input(self, sample_tensor): + op = ConcatenatedToInterleaved(dim=1) + result = op.convert( + {"test": [sample_tensor]}, + source_patterns=["test"], + target_patterns=["test"], + ) + assert result["test"].shape == sample_tensor.shape + + def test_reverse_op_type(self): + op = ConcatenatedToInterleaved(dim=1) + assert isinstance(op.reverse_op, InterleavedToConcatenated) + assert op.reverse_op.dim == 1 + + +class TestInterleavedToConcatenated: + @pytest.fixture + def interleaved_tensor(self): + """Create an interleaved tensor [E=2, 2*I=4, H=3].""" + E, I, H = 2, 2, 3 # noqa: E741 + gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H) + up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H) + interleaved = torch.empty(E, 2 * I, H) + interleaved[:, 0::2, :] = gate + interleaved[:, 1::2, :] = up + return interleaved + + def test_deinterleave_gate_up_separated(self, interleaved_tensor): + op = InterleavedToConcatenated(dim=1) + result = op.convert( + {"test": interleaved_tensor}, + source_patterns=["test"], + target_patterns=["test"], + ) + concatenated = result["test"] + + E, two_I, H = concatenated.shape + I = two_I // 2 # noqa: E741 + + # First half should be gate (even rows from interleaved) + assert torch.equal(concatenated[:, :I, :], interleaved_tensor[:, 0::2, :]) + # Second half should be up (odd rows from interleaved) + assert torch.equal(concatenated[:, I:, :], interleaved_tensor[:, 1::2, :]) + + def test_reverse_op_type(self): + op = InterleavedToConcatenated(dim=1) + assert isinstance(op.reverse_op, ConcatenatedToInterleaved) + assert op.reverse_op.dim == 1 + + +class TestRoundTrip: + @pytest.fixture + def concat_tensor(self): + E, I, H = 4, 8, 16 # noqa: E741 + gate = torch.randn(E, I, H) + up = torch.randn(E, I, H) + return torch.cat([gate, up], dim=1) + + def test_interleave_then_deinterleave_is_identity(self, concat_tensor): + fwd = ConcatenatedToInterleaved(dim=1) + rev = InterleavedToConcatenated(dim=1) + + interleaved = fwd.convert( + {"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"] + )["k"] + recovered = rev.convert( + {"k": interleaved}, source_patterns=["k"], target_patterns=["k"] + )["k"] + + assert torch.equal(concat_tensor, recovered) + + def test_reverse_op_chain_is_identity(self, concat_tensor): + """Verify that op.reverse_op produces an exact inverse.""" + op = ConcatenatedToInterleaved(dim=1) + rev = op.reverse_op + + interleaved = op.convert( + {"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"] + )["k"] + recovered = rev.convert( + {"k": interleaved}, source_patterns=["k"], target_patterns=["k"] + )["k"] + + assert torch.equal(concat_tensor, recovered) + + def test_various_shapes(self): + """Test with different expert counts and dimensions.""" + fwd = ConcatenatedToInterleaved(dim=1) + rev = InterleavedToConcatenated(dim=1) + + for E, I, H in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]: # noqa: E741 + concat = torch.randn(E, 2 * I, H) + interleaved = fwd.convert( + {"k": concat}, source_patterns=["k"], target_patterns=["k"] + )["k"] + recovered = rev.convert( + {"k": interleaved}, source_patterns=["k"], target_patterns=["k"] + )["k"] + assert torch.equal(concat, recovered), ( + f"Failed for shape ({E}, {2 * I}, {H})" + ) + + +class TestWeightConverterRegistration: + 
def test_register_appends_interleave_op(self): + from transformers.conversion_mapping import get_checkpoint_conversion_mapping + + register_sonicmoe_weight_converter("qwen3_moe") + + modified = get_checkpoint_conversion_mapping("qwen3_moe") + # Find the gate_up_proj converter + gate_up_converter = None + for conv in modified: + if hasattr(conv, "operations") and any( + "gate_up_proj" in pat for pat in conv.target_patterns + ): + gate_up_converter = conv + break + + assert gate_up_converter is not None + assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved) + + def test_double_registration_is_idempotent(self): + from transformers.conversion_mapping import get_checkpoint_conversion_mapping + + register_sonicmoe_weight_converter("qwen3_moe") + register_sonicmoe_weight_converter("qwen3_moe") + + modified = get_checkpoint_conversion_mapping("qwen3_moe") + for conv in modified: + if hasattr(conv, "operations") and any( + "gate_up_proj" in pat for pat in conv.target_patterns + ): + interleave_count = sum( + isinstance(op, ConcatenatedToInterleaved) for op in conv.operations + ) + assert interleave_count == 1, ( + f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}" + ) + break + + def test_register_unsupported_model_type_warns(self): + # A model type with no conversion mapping should warn but not raise + register_sonicmoe_weight_converter("nonexistent_model_type_xyz") + + +def _make_qwen_moe_block(T=8, H=16, E=4, K=2): + """Create a mock qwen-style MoE block for routing tests.""" + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + return SimpleNamespace(gate=gate), T, H, E, K + + +def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1): + """Create a mock GLM5-style MoE block for routing tests.""" + gate = SimpleNamespace( + weight=torch.randn(E, H), + e_score_correction_bias=torch.zeros(E), + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + n_routed_experts=E, + n_group=n_group, + topk_group=topk_group, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + return moe_block, T, H, E, K + + +def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4): + """Create a mock minimax_m2-style MoE block for routing tests. 
+ + minimax_m2 uses sigmoid->topk WITHOUT group selection: + - e_score_correction_bias is on the moe_block (not on gate) + - No n_group / topk_group attributes + - Always normalizes (norm_topk_prob defaults to True) + - No routed_scaling_factor (defaults to 1.0) + """ + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + e_score_correction_bias=torch.zeros(E), + ) + return moe_block, T, H, E, K + + +class TestSoftmaxTopkRouting: + def test_output_shapes(self): + moe_block, T, H, E, K = _make_qwen_moe_block() + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E) + + def test_scores_are_float32(self): + moe_block, T, H, E, K = _make_qwen_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_topk_routing(hidden, moe_block) + assert scores.dtype == torch.float32 + + def test_token_indices_sorted_ascending(self): + moe_block, T, H, E, K = _make_qwen_moe_block() + hidden = torch.randn(T, H) + + _, token_idx, _, _ = softmax_topk_routing(hidden, moe_block) + + # Token indices must be sorted ascending (SonicMoE requirement) + diffs = token_idx[1:] - token_idx[:-1] + assert (diffs >= 0).all() + + def test_expert_indices_in_range(self): + moe_block, T, H, E, K = _make_qwen_moe_block() + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = softmax_topk_routing(hidden, moe_block) + + assert (expert_idx >= 0).all() + assert (expert_idx < E).all() + + def test_renormalized_scores_sum_to_one(self): + moe_block, T, H, E, K = _make_qwen_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_topk_routing(hidden, moe_block) + per_token_sums = scores.reshape(T, K).sum(dim=-1) + assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5) + + +class TestSigmoidTopkRouting: + def test_output_shapes(self): + moe_block, T, H, E, K = _make_glm_moe_block() + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E) + + def test_scores_are_float32(self): + moe_block, T, H, E, K = _make_glm_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) + assert scores.dtype == torch.float32 + + def test_token_indices_sorted_ascending(self): + moe_block, T, H, E, K = _make_glm_moe_block() + hidden = torch.randn(T, H) + + _, token_idx, _, _ = sigmoid_topk_routing(hidden, moe_block) + + diffs = token_idx[1:] - token_idx[:-1] + assert (diffs >= 0).all() + + def test_expert_indices_in_range(self): + moe_block, T, H, E, K = _make_glm_moe_block() + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + assert (expert_idx >= 0).all() + assert (expert_idx < E).all() + + def test_scores_are_nonnegative(self): + """Sigmoid outputs are in [0, 1], so scores should be non-negative.""" + moe_block, T, H, E, K = _make_glm_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) + assert (scores >= 0).all() + + def test_scaling_factor_applied(self): + moe_block, T, H, E, K = _make_glm_moe_block() + hidden = torch.randn(T, H) + + # Get scores with scaling_factor=1.0 + scores_1x, _, _, _ = 
sigmoid_topk_routing(hidden, moe_block) + + # Get scores with scaling_factor=2.0 + moe_block.routed_scaling_factor = 2.0 + scores_2x, _, _, _ = sigmoid_topk_routing(hidden, moe_block) + + assert torch.allclose(scores_2x, scores_1x * 2.0, atol=1e-5) + + def test_group_selection_restricts_experts(self): + """With n_group=4 and topk_group=1, only 1/4 of experts should be selectable.""" + moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1) + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + # Each token's experts should all fall within a single group (size E//n_group=4) + expert_idx_2d = expert_idx.reshape(T, K) + for t in range(T): + experts = expert_idx_2d[t] + groups = experts // (E // moe_block.n_group) + # All selected experts should be from the same group + assert (groups == groups[0]).all() + + +class TestMiniMaxM2SigmoidRouting: + """Tests for minimax_m2 routing: sigmoid->topk without group selection.""" + + def test_output_shapes(self): + """Validates getattr defaults work: n_group=1, E from gate.weight.shape[0].""" + moe_block, T, H, E, K = _make_minimax_m2_moe_block() + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E) + + def test_bias_on_block_not_gate(self): + """Verify that e_score_correction_bias on the block (not gate) is used.""" + T, H, E, K = 8, 16, 8, 2 + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + ) + # Large positive bias on expert 0 should make it selected more often + bias = torch.zeros(E) + bias[0] = 100.0 + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + e_score_correction_bias=bias, + ) + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + # Expert 0 should appear for every token due to the large bias + expert_idx_2d = expert_idx.reshape(T, K) + for t in range(T): + assert 0 in expert_idx_2d[t] diff --git a/tests/integrations/test_sonicmoe_gradients.py b/tests/integrations/test_sonicmoe_gradients.py new file mode 100644 index 000000000..e76bdd480 --- /dev/null +++ b/tests/integrations/test_sonicmoe_gradients.py @@ -0,0 +1,158 @@ +""" +Gradient correctness tests for SonicMoE routing functions (CPU-only). + +Uses torch.autograd.gradcheck with float32 inputs to match the production +code path where routing happens in float32. 
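+
+Note: gradcheck is designed for double-precision inputs, so the relaxed
+eps/atol/rtol constants below (1e-3) are deliberately loose to make the
+float32 checks feasible. Top-k index selection is piecewise-constant, so
+the checks only probe the differentiable branch; random inputs almost
+surely avoid routing ties where that branch would change.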
+""" + +import torch + +from axolotl.integrations.kernels.sonicmoe.routing import ( + sigmoid_topk_routing, + softmax_topk_routing, +) + +_GC_EPS = 1e-3 +_GC_ATOL = 1e-3 +_GC_RTOL = 1e-3 + + +def _make_softmax_moe_block(weight): + gate = torch.nn.Module() + gate.weight = weight + gate.top_k = 2 + gate.norm_topk_prob = True + + moe_block = torch.nn.Module() + moe_block.gate = gate + return moe_block + + +def _make_sigmoid_moe_block(weight, bias): + gate = torch.nn.Module() + gate.weight = weight + gate.e_score_correction_bias = bias + + moe_block = torch.nn.Module() + moe_block.gate = gate + moe_block.top_k = 2 + moe_block.n_routed_experts = weight.shape[0] + moe_block.n_group = 1 + moe_block.norm_topk_prob = True + moe_block.routed_scaling_factor = 1.0 + return moe_block + + +class TestSoftmaxTopkRoutingGradcheck: + """Numerical gradient verification for softmax_topk_routing.""" + + def test_gradcheck_wrt_gate_weight(self): + T, H, E = 4, 8, 4 + + hidden = torch.randn(T, H, dtype=torch.float32) + + def fn(weight): + moe_block = _make_softmax_moe_block(weight) + scores, _, _, _ = softmax_topk_routing(hidden, moe_block) + return scores + + weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck( + fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL + ) + + def test_gradcheck_wrt_hidden_states(self): + T, H, E = 4, 8, 4 + + weight = torch.randn(E, H, dtype=torch.float32) + moe_block = _make_softmax_moe_block(weight) + + def fn(hidden): + scores, _, _, _ = softmax_topk_routing(hidden, moe_block) + return scores + + hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck( + fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL + ) + + def test_gradcheck_wrt_router_logits(self): + T, H, E = 4, 8, 4 + + hidden = torch.randn(T, H, dtype=torch.float32) + + def fn(weight): + moe_block = _make_softmax_moe_block(weight) + _, _, _, router_logits = softmax_topk_routing(hidden, moe_block) + return router_logits + + weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck( + fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL + ) + + def test_no_norm_variant(self): + T, H, E = 4, 8, 4 + + hidden = torch.randn(T, H, dtype=torch.float32) + + def fn(weight): + moe_block = _make_softmax_moe_block(weight) + moe_block.gate.norm_topk_prob = False + scores, _, _, _ = softmax_topk_routing(hidden, moe_block) + return scores + + weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck( + fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL + ) + + +class TestSigmoidTopkRoutingGradcheck: + """Numerical gradient verification for sigmoid_topk_routing.""" + + def test_gradcheck_wrt_gate_weight(self): + T, H, E = 4, 8, 4 + + hidden = torch.randn(T, H, dtype=torch.float32) + bias = torch.zeros(E, dtype=torch.float32) + + def fn(weight): + moe_block = _make_sigmoid_moe_block(weight, bias) + scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) + return scores + + weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck( + fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL + ) + + def test_gradcheck_wrt_hidden_states(self): + T, H, E = 4, 8, 4 + + weight = torch.randn(E, H, dtype=torch.float32) + bias = torch.zeros(E, dtype=torch.float32) + moe_block = _make_sigmoid_moe_block(weight, bias) + + def fn(hidden): + scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) + return scores + + hidden = 
torch.randn(T, H, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck( + fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL + ) + + def test_gradcheck_wrt_bias(self): + T, H, E = 4, 8, 4 + + hidden = torch.randn(T, H, dtype=torch.float32) + weight = torch.randn(E, H, dtype=torch.float32) + + def fn(bias): + moe_block = _make_sigmoid_moe_block(weight, bias) + scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) + return scores + + bias = torch.zeros(E, dtype=torch.float32, requires_grad=True) + torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)
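+
+
+# These gradcheck tests are CPU-only and can be run standalone with, e.g.:
+#   pytest tests/integrations/test_sonicmoe_gradients.py -q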