consolidate behavior of routing in scattermoe kernels (#3475)

* consolidate behavior of routing in scattermoe kernels
* collect telemetry on best chosen autotuned kernel
* properly collect data
* Fix property name and get smem too
* handle issues raised by coderabbit
* add tests for parity before refactoring
src/axolotl/integrations/kernels/autotune_callback.py (new file, +120)
@@ -0,0 +1,120 @@
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""

import logging

import torch
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

LOG = logging.getLogger(__name__)

# Give up looking for autotune data after this many training steps.
_MAX_POLL_STEP = 5


def _get_gpu_info() -> dict:
    """Return basic GPU identification for the current device."""
    if not torch.cuda.is_available():
        return {}
    try:
        idx = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(idx)
        return {
            "gpu_name": props.name,
            "gpu_compute_capability": f"{props.major}.{props.minor}",
            "gpu_memory_bytes": props.total_memory,
        }
    except Exception:  # pylint: disable=broad-exception-caught
        return {}


def _get_smem_capacity() -> dict:
    """Return shared memory capacity from the runtime lora_ops module."""
    try:
        from axolotl.integrations.kernels.autotune_collector import (
            _find_lora_ops_module,
        )

        lora_ops = _find_lora_ops_module()
        if lora_ops is None:
            return {}
        fn = getattr(lora_ops, "_get_smem_capacity", None)
        if fn is None:
            return {}
        return {"smem_capacity_bytes": fn()}
    except Exception:  # pylint: disable=broad-exception-caught
        return {}


class AutotuneReportCallback(TrainerCallback):
    """Reports Triton kernel autotune selections via telemetry.

    Fires **once** after the first training step completes (step 1), at
    which point the forward and backward passes have both run and the
    autotuned kernels have populated their caches. If for some reason
    the caches are still empty (e.g. the kernel was never invoked), the
    callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
    then stops polling.

    After reporting (or giving up) every subsequent ``on_step_end``
    call short-circuits on the ``_reported`` flag — zero hot-path cost.
    """

    def __init__(self):
        self._reported = False

    # pylint: disable=unused-argument
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if self._reported:
            return

        # Lazy import — Triton / scattermoe kernels may not be installed.
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        configs = collect_autotune_configs()

        if not configs:
            if state.global_step >= _MAX_POLL_STEP:
                LOG.debug(
                    "No autotune data found after %d steps; giving up.",
                    state.global_step,
                )
                self._reported = True
            return

        self._reported = True

        from axolotl.telemetry.manager import TelemetryManager

        telemetry_manager = TelemetryManager.get_instance()
        if not telemetry_manager.enabled:
            return

        properties = {
            "kernel_count": len(configs),
            "kernels": configs,
        }
        properties.update(_get_gpu_info())
        properties.update(_get_smem_capacity())

        telemetry_manager.send_event(
            event_type="scattermoe-autotune",
            properties=properties,
        )

        LOG.info(
            "Reported %d scattermoe kernel autotune config(s) to telemetry.",
            len(configs),
        )
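For reference, a minimal sketch (not part of the diff) of the payload this callback assembles, assuming ``triton.Config``-style attributes (``kwargs``, ``num_warps``, ``num_stages``), the same shape the collector and the mocked tests below rely on; all values here are illustrative:

from types import SimpleNamespace

# A hypothetical autotune selection, shaped like one collector output entry.
cfg = SimpleNamespace(kwargs={"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4)
configs = [
    {
        "kernel": "scatter2scatter_lora_fwd",
        "key": {"M": 2048, "N": 4096, "K": 1024},
        "config": {**cfg.kwargs, "num_warps": cfg.num_warps, "num_stages": cfg.num_stages},
    }
]

# The event body the callback sends, before GPU/smem info is merged in.
properties = {"kernel_count": len(configs), "kernels": configs}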
src/axolotl/integrations/kernels/autotune_collector.py (new file, +114)
@@ -0,0 +1,114 @@
"""Collect Triton autotune results from scattermoe-lora kernels.

This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
decorated kernel objects and returns structured dicts describing the selected
configurations. It has **no** telemetry dependency — callers decide what to
do with the data.
"""

import logging
import sys
from types import ModuleType
from typing import Any

LOG = logging.getLogger(__name__)

# (human-readable name, attribute on the lora_ops module)
_KERNEL_REGISTRY: list[tuple[str, str]] = [
    ("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
    ("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
    ("group_bwd_lora", "_group_bwd_lora"),
    ("group_bwd_lora_fused", "_group_bwd_lora_fused"),
]

# The autotune key declared on every kernel: key=["M", "N", "K"]
_KEY_NAMES: list[str] = ["M", "N", "K"]


def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
    """Turn the autotune cache key tuple into a labelled dict.

    Triton builds the cache key from the values of the declared ``key``
    args (``M``, ``N``, ``K``) followed by dtype signature elements.
    We label the first three and store the rest under ``_extra``.
    """
    result: dict[str, Any] = {}
    for i, name in enumerate(_KEY_NAMES):
        if i < len(key_tuple):
            result[name] = key_tuple[i]
    if len(key_tuple) > len(_KEY_NAMES):
        result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
    return result


def _find_lora_ops_module() -> ModuleType | None:
    """Locate the *runtime* ``lora_ops`` module in ``sys.modules``.

    The HF ``kernels`` package loads ``scattermoe_lora`` via
    ``import_from_path`` which registers it in ``sys.modules`` under a
    hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
    import (``from axolotl.integrations.kernels...``) would create a
    *separate* module instance whose kernel objects have empty
    ``.cache`` dicts because autotuning ran on the runtime copy.

    We search ``sys.modules`` for any module whose name contains
    ``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
    attribute — that is the runtime copy with populated caches.
    """
    for name, module in sys.modules.items():
        if (
            module is not None
            and "lora_ops" in name
            and hasattr(module, "_scatter2scatter_lora")
        ):
            return module
    return None


def collect_autotune_configs() -> list[dict[str, Any]]:
    """Read autotune caches from the four scattermoe-lora kernels.

    Returns a (possibly empty) list of dicts, each containing:

    * ``kernel`` – human-readable kernel name
    * ``key`` – dict with the ``M``/``N``/``K`` problem dimensions
    * ``config`` – dict with the selected tile sizes, ``num_warps``,
      and ``num_stages``

    Returns ``[]`` if the kernel module cannot be found or if no
    autotune cache entries exist yet.
    """
    lora_ops = _find_lora_ops_module()
    if lora_ops is None:
        LOG.debug(
            "lora_ops module not found in sys.modules; skipping autotune collection"
        )
        return []

    results: list[dict[str, Any]] = []

    for friendly_name, attr_name in _KERNEL_REGISTRY:
        kernel_fn = getattr(lora_ops, attr_name, None)
        if kernel_fn is None:
            continue

        cache = getattr(kernel_fn, "cache", None)
        if not cache:
            continue

        for key_tuple, config in cache.items():
            config_dict = dict(config.kwargs)
            config_dict["num_warps"] = config.num_warps
            config_dict["num_stages"] = config.num_stages
            if getattr(config, "num_ctas", None) is not None:
                config_dict["num_ctas"] = config.num_ctas

            results.append(
                {
                    "kernel": friendly_name,
                    "key": _parse_key_tuple(key_tuple),
                    "config": config_dict,
                }
            )

    return results
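A minimal usage sketch, assuming a populated runtime copy registered under a hash-suffixed name (the suffix and all numbers here are illustrative, mirroring the mocks in the tests below):

import sys
from types import SimpleNamespace

# Fake a populated runtime kernel module under a hash-suffixed name,
# the way the HF ``kernels`` loader would register it.
kernel = SimpleNamespace(
    cache={(16, 256, 128): SimpleNamespace(kwargs={"BLOCK_N": 128}, num_warps=4, num_stages=3)}
)
sys.modules["scattermoe_lora_abc123.kernels.lora_ops"] = SimpleNamespace(
    _scatter2scatter_lora=kernel
)

from axolotl.integrations.kernels.autotune_collector import collect_autotune_configs

print(collect_autotune_configs())
# [{'kernel': 'scatter2scatter_lora_fwd', 'key': {'M': 16, 'N': 256, 'K': 128},
#   'config': {'BLOCK_N': 128, 'num_warps': 4, 'num_stages': 3}}]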
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
    return base_experts, gup_lora, down_lora


# =============================================================================
# Routing helpers
# =============================================================================


def _softmax_topk_route(
    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
    """Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).

    Returns:
        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
    """
    router_logits = F.linear(hidden_states, gate_weight)
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)

    top_k = base_gate.top_k
    num_experts = base_gate.num_experts
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    if getattr(base_gate, "norm_topk_prob", True):
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    return routing_weights, selected_experts, top_k, num_experts


def _sigmoid_topk_route(
    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
    """Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).

    Supports:
    - ``e_score_correction_bias`` on gate or moe_block
    - Group-based expert selection when ``n_group > 1``
    - ``routed_scaling_factor`` applied to final weights
    - Final weights gathered from original sigmoid probs (not bias-corrected)

    Returns:
        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
    """
    router_logits = F.linear(hidden_states.float(), gate_weight.float())
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = router_logits.sigmoid()  # [T, E]

    top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
    num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is not None:
        scores_for_choice = router_probs + e_score_correction_bias
    else:
        scores_for_choice = router_probs

    # Group-based selection: pick top groups, mask the rest
    n_group = getattr(moe_block, "n_group", 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, num_experts // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        topk_group = getattr(moe_block, "topk_group", n_group)
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, n_group, num_experts // n_group)
            .reshape(-1, num_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    # Final topk from (possibly masked) scores
    topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]

    # Gather weights from original sigmoid scores (not bias-corrected)
    topk_weights = router_probs.gather(1, topk_indices)

    # Optional renormalization + scaling
    if getattr(moe_block, "norm_topk_prob", True):
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    return topk_weights, topk_indices, top_k, num_experts


def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
    """Dispatch to the correct routing strategy based on block attributes.

    Detects sigmoid routing by the presence of ``e_score_correction_bias``
    on either the gate or the moe_block.
    """
    has_sigmoid = (
        getattr(base_gate, "e_score_correction_bias", None) is not None
        or getattr(moe_block, "e_score_correction_bias", None) is not None
    )
    if has_sigmoid:
        return _sigmoid_topk_route(
            moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
        )
    return _softmax_topk_route(
        moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
    )


# =============================================================================
# Shared expert helpers
# =============================================================================


def _compute_shared_expert(moe_block, hidden_states_flat):
    """Compute shared expert output if the block has one.

    Handles singular (qwen2_moe: ``shared_expert``), plural
    (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
    (hunyuan_v1_moe: ``shared_mlp``) attribute names.

    peft wraps individual linear layers inside the shared expert with
    standard LoRA — calling forward() handles this transparently.
    """
    shared_expert = (
        getattr(moe_block, "shared_expert", None)
        or getattr(moe_block, "shared_experts", None)
        or getattr(moe_block, "shared_mlp", None)
    )
    if shared_expert is None:
        return None

    shared_expert_output = shared_expert(hidden_states_flat)

    # Optional sigmoid gate (Qwen2MoE pattern).
    # shared_expert_gate may also be peft-wrapped (standard LoRA
    # on nn.Linear), its forward() applies LoRA automatically.
    shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
    if shared_expert_gate is not None:
        shared_expert_output = (
            F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
        )

    return shared_expert_output


# =============================================================================
# Layer classes
# =============================================================================
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):
 
 
 class HFScatterMoEGatedMLP(nn.Module):
     """
-    ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
+    ScatterMoE-accelerated forward pass for HF MoEs.
 
     Used as a kernel layer via the HF ``kernels`` library. The ``forward``
-    method replaces the original ``OlmoeSparseMoeBlock.forward``.
+    method replaces the original SparseMoeBlock.forward.
 
-    Supports both full-parameter training and LoRA fine-tuning:
+    Supports:
 
-    * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
-    * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
-      adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+    * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
+    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+    * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
+    * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
+      extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
     """
 
     @staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
             self: The MoeSparseMoeBlock module containing:
                 - self.gate: Router (or peft ParamWrapper wrapping it)
                 - self.experts: Experts module (or peft ParamWrapper chain)
-                - self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
+                - self.shared_expert(s): Optional shared expert
                 - self.shared_expert_gate: Optional shared expert gate
             layer_input: Input tensor [batch_size, seq_len, hidden_size]
 
@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
         hidden_states_flat = layer_input.view(-1, hidden_dim)
 
         # ====================================================================
-        # Shared Expert (if present, e.g. Qwen2MoE)
+        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
         # ====================================================================
-        # peft wraps individual linear layers inside shared_expert with
-        # standard LoRA — calling forward() handles this transparently.
-        if hasattr(self, "shared_expert") and self.shared_expert is not None:
-            shared_expert_output = self.shared_expert(hidden_states_flat)
-            # shared_expert_gate may also be peft-wrapped (standard LoRA
-            # on nn.Linear), its forward() applies LoRA automatically.
-            shared_expert_gate_output = F.sigmoid(
-                self.shared_expert_gate(hidden_states_flat)
-            )
-            shared_expert_output = shared_expert_output * shared_expert_gate_output
-        else:
-            shared_expert_output = None
+        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
 
         # ====================================================================
         # Router Computation (with optional gate LoRA)
         # ====================================================================
         base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
-        router_logits = F.linear(hidden_states_flat, gate_weight)
-        if gate_lora_delta is not None:
-            router_logits = router_logits + F.linear(
-                hidden_states_flat, gate_lora_delta
-            )
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-
-        top_k = base_gate.top_k
-        num_experts = base_gate.num_experts
-        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-
-        if base_gate.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights, selected_experts, top_k, num_experts = _route(
+            self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
+        )
         routing_weights = routing_weights.to(hidden_states_flat.dtype)
 
         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
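To make the group-based selection step concrete, a toy walk-through of the sigmoid path with illustrative numbers (E=4 experts in n_group=2 groups, topk_group=1, top_k=2). Note that expert 0 holds the single highest prob but loses because its group's aggregate score is lower:

import torch

probs = torch.tensor([[0.9, 0.1, 0.6, 0.5]])  # sigmoid scores, [T=1, E=4]
n_group, num_experts, topk_group, top_k = 2, 4, 1, 2

# Group score = sum of the top-2 probs per group: 0.9+0.1=1.0 vs 0.6+0.5=1.1
group_scores = probs.view(-1, n_group, 2).topk(2, dim=-1)[0].sum(-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]  # group 1 wins
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, 2).reshape(-1, num_experts)
masked = probs.masked_fill(~score_mask.bool(), 0.0)  # [[0.0, 0.0, 0.6, 0.5]]

topk_idx = torch.topk(masked, k=top_k, dim=-1)[1]  # experts 2 and 3
weights = probs.gather(1, topk_idx)                # gathered from original probs
weights = weights / weights.sum(-1, keepdim=True)  # ~[[0.545, 0.455]]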
@@ -110,6 +110,16 @@ class KernelsPlugin(BasePlugin):
            }
        )

    def add_callbacks_pre_trainer(self, cfg, model):
        callbacks = []
        if cfg.use_scattermoe:
            from axolotl.integrations.kernels.autotune_callback import (
                AutotuneReportCallback,
            )

            callbacks.append(AutotuneReportCallback())
        return callbacks

    def _kernelize_model(self, model_type: str):
        from kernels import replace_kernel_forward_from_hub

@@ -12,6 +12,7 @@ Tests verify correctness of:
3. Frozen weights: expert weight gradients are correctly skipped
4. Various configurations: top-k, grouped_in/out, with/without bias
5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference
6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2)

Test strategy:
- Reference implementation uses pure PyTorch ops (no Triton)
@@ -19,6 +20,8 @@ Test strategy:
- Tolerances account for tf32 accumulation in Triton kernels
"""

from types import SimpleNamespace

import pytest
import torch

@@ -1476,3 +1479,347 @@ class TestCombinedOptimizations:
        torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)
        torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)
        torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)


# =============================================================================
# Test: HFScatterMoEGatedMLP with Sigmoid Routing
# =============================================================================


def _reference_moe_forward(
    hidden_states,
    gate_weight,
    gate_up_proj,
    down_proj,
    act_fn,
    routing_weights,
    selected_experts,
    num_experts,
):
    """Pure PyTorch reference for a full MoE forward pass.

    Args:
        hidden_states: [T, H]
        gate_weight: [E, H]
        gate_up_proj: [E, 2*FF, H]
        down_proj: [E, H, FF]
        act_fn: activation function (e.g. torch.nn.SiLU())
        routing_weights: [T, K] routing weights
        selected_experts: [T, K] expert indices
        num_experts: int

    Returns:
        output: [T, H]
    """
    T, H = hidden_states.shape
    K = selected_experts.shape[1]
    output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype)

    for t in range(T):
        for j in range(K):
            e = selected_experts[t, j].item()
            w = routing_weights[t, j].item()

            # gate_up projection
            gup = hidden_states[t] @ gate_up_proj[e].T  # [2*I]
            I_dim = gup.shape[0] // 2
            gates = gup[:I_dim]
            up = gup[I_dim:]

            # activation
            h = act_fn(gates) * up

            # down projection
            out = h @ down_proj[e].T  # [H]

            output[t] += w * out

    return output


def _make_mock_sigmoid_moe_block(
    T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
):
    """Create a mock MoE block with sigmoid routing for GPU testing."""
    gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
    down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
    act_fn = torch.nn.SiLU()

    experts = SimpleNamespace(
        gate_up_proj=gate_up_proj,
        down_proj=down_proj,
        act_fn=act_fn,
        num_experts=E,
    )

    if bias_on_gate:
        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            e_score_correction_bias=torch.zeros(E, device="cuda"),
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style
        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            e_score_correction_bias=torch.zeros(E, device="cuda"),
        )

    return moe_block, T, H, FF, E, K


class TestHFScatterMoESigmoidRouting:
    """Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""

    def test_forward_matches_reference_bias_on_gate(self):
        """Forward pass with sigmoid routing (bias on gate) matches reference."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
            _sigmoid_topk_route,
        )

        moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(
            T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
        )

        hidden = torch.randn(1, T, H, device="cuda")

        # Get routing for reference
        gate = moe_block.gate
        hidden_flat = hidden.view(-1, H)
        routing_weights, selected_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden_flat, gate.weight, None
        )

        # Reference output
        ref_output = _reference_moe_forward(
            hidden_flat,
            gate.weight,
            moe_block.experts.gate_up_proj,
            moe_block.experts.down_proj,
            moe_block.experts.act_fn,
            routing_weights,
            selected_experts,
            E,
        )

        # Kernel output
        kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        kernel_output_flat = kernel_output.view(-1, H)

        torch.testing.assert_close(
            kernel_output_flat.float(),
            ref_output.float(),
            atol=5e-2,
            rtol=5e-2,
        )

    def test_forward_matches_reference_bias_on_block(self):
        """Forward pass with sigmoid routing (minimax_m2 style, bias on block)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
            _sigmoid_topk_route,
        )

        moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(
            T=16, H=64, FF=32, E=8, K=2, n_group=1, bias_on_gate=False
        )

        hidden = torch.randn(1, T, H, device="cuda")
        hidden_flat = hidden.view(-1, H)

        gate = moe_block.gate
        routing_weights, selected_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden_flat, gate.weight, None
        )

        ref_output = _reference_moe_forward(
            hidden_flat,
            gate.weight,
            moe_block.experts.gate_up_proj,
            moe_block.experts.down_proj,
            moe_block.experts.act_fn,
            routing_weights,
            selected_experts,
            E,
        )

        kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        kernel_output_flat = kernel_output.view(-1, H)

        torch.testing.assert_close(
            kernel_output_flat.float(),
            ref_output.float(),
            atol=5e-2,
            rtol=5e-2,
        )

    def test_softmax_routing_still_works(self):
        """Verify softmax routing (Qwen/OLMoE) is not broken."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
            _softmax_topk_route,
        )

        T, H, FF, E, K = 16, 64, 32, 4, 2
        gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
        down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
        act_fn = torch.nn.SiLU()

        experts = SimpleNamespace(
            gate_up_proj=gate_up_proj,
            down_proj=down_proj,
            act_fn=act_fn,
            num_experts=E,
        )
        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
            num_experts=E,
            norm_topk_prob=True,
        )
        moe_block = SimpleNamespace(gate=gate, experts=experts)

        hidden = torch.randn(1, T, H, device="cuda")
        hidden_flat = hidden.view(-1, H)

        routing_weights, selected_experts, _, _ = _softmax_topk_route(
            moe_block, gate, hidden_flat, gate.weight, None
        )

        ref_output = _reference_moe_forward(
            hidden_flat,
            gate.weight,
            gate_up_proj,
            down_proj,
            act_fn,
            routing_weights,
            selected_experts,
            E,
        )

        kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        kernel_output_flat = kernel_output.view(-1, H)

        torch.testing.assert_close(
            kernel_output_flat.float(),
            ref_output.float(),
            atol=5e-2,
            rtol=5e-2,
        )


class TestHFScatterMoESigmoidWithSharedExperts:
    """Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""

    def test_shared_experts_plural(self):
        """DeepSeek V3 style: shared_experts attribute (plural)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        T, H, FF, E, K = 8, 64, 32, 8, 2
        gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
        down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
        act_fn = torch.nn.SiLU()

        experts = SimpleNamespace(
            gate_up_proj=gate_up_proj,
            down_proj=down_proj,
            act_fn=act_fn,
            num_experts=E,
        )

        # Shared expert as a simple linear for testing
        shared_W = torch.randn(H, H, device="cuda") * 0.01
        shared_experts_fn = lambda x: x @ shared_W.T  # noqa: E731

        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            e_score_correction_bias=torch.zeros(E, device="cuda"),
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            shared_experts=shared_experts_fn,
            top_k=K,
            n_routed_experts=E,
            n_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )

        hidden = torch.randn(1, T, H, device="cuda")

        # Should not raise; output should include shared expert contribution
        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        assert output.shape == (1, T, H)

        # Run without shared expert to verify it changes the output
        moe_block_no_shared = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            n_routed_experts=E,
            n_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        output_no_shared = HFScatterMoEGatedMLP.forward(moe_block_no_shared, hidden)
        assert not torch.equal(output, output_no_shared)

    def test_shared_expert_with_gate(self):
        """Qwen2MoE style: shared_expert + shared_expert_gate."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        T, H, FF, E, K = 8, 64, 32, 4, 2
        gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
        down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
        act_fn = torch.nn.SiLU()

        experts = SimpleNamespace(
            gate_up_proj=gate_up_proj,
            down_proj=down_proj,
            act_fn=act_fn,
            num_experts=E,
        )

        shared_W = torch.randn(H, H, device="cuda") * 0.01
        shared_expert_fn = lambda x: x @ shared_W.T  # noqa: E731
        # Gate that returns 0 -> sigmoid(0) = 0.5
        gate_W = torch.zeros(H, H, device="cuda")
        shared_expert_gate_fn = lambda x: x @ gate_W.T  # noqa: E731

        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
            num_experts=E,
            norm_topk_prob=True,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            shared_expert=shared_expert_fn,
            shared_expert_gate=shared_expert_gate_fn,
        )

        hidden = torch.randn(1, T, H, device="cuda")
        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        assert output.shape == (1, T, H)
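As a quick check on the gating math the shared-expert tests above depend on: a zero-weight gate produces logits of 0, and sigmoid(0) = 0.5, so the gated output is exactly half the raw shared-expert output. A minimal sketch:

import torch
import torch.nn.functional as F

shared_out = torch.ones(4, 8)    # stand-in for the shared expert's output
gate_logits = torch.zeros(4, 8)  # zero gate weights -> logits are all 0
gated = F.sigmoid(gate_logits) * shared_out
assert torch.allclose(gated, shared_out * 0.5)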
tests/integrations/test_routing_parity.py (new file, +474)
@@ -0,0 +1,474 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
Parity tests between scattermoe-lora and sonicmoe routing implementations.

These tests verify that both implementations produce numerically identical
results for the same inputs, ensuring safe centralization of the routing code.

ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
The core algorithm should be identical — only the output format differs.
"""

from types import SimpleNamespace

import pytest
import torch


def _require_triton():
    pytest.importorskip("triton")


# ============================================================================
# Fixtures / helpers
# ============================================================================


def _make_softmax_block(T=8, H=16, E=4, K=2):
    """Qwen/OLMoE-style block usable by both implementations."""
    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        num_experts=E,
        norm_topk_prob=True,
    )
    moe_block = SimpleNamespace(gate=gate)
    hidden = torch.randn(T, H)
    return moe_block, gate, hidden, T, H, E, K


def _make_sigmoid_block(
    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
    """GLM/DeepSeek-style block usable by both implementations."""
    if bias_on_gate:
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style: bias on block
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            e_score_correction_bias=torch.zeros(E),
        )
    return moe_block, gate, hidden_states(T, H), T, H, E, K


def hidden_states(T, H):
    return torch.randn(T, H)


# ============================================================================
# 1. Softmax routing parity
# ============================================================================


class TestSoftmaxRoutingParity:
    """Verify scattermoe and sonicmoe softmax routing produce identical results."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def test_weights_match(self):
        """2D weights from scattermoe == reshaped 1D weights from sonicmoe."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()

        # ScatterMoE path (no LoRA delta)
        sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # SonicMoE path
        sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(
            hidden, moe_block
        )

        # ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        assert sm_topk == K
        assert sm_E == E

        # Both should select the same experts and produce the same weights
        assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)

    def test_logits_not_returned_by_scattermoe(self):
        """ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
        _, _, _, logits = softmax_topk_routing(hidden, moe_block)
        assert logits.shape == (T, E)

    def test_no_renorm(self):
        """With norm_topk_prob=False, both should skip renormalization."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
        gate.norm_topk_prob = False

        sm_weights, sm_experts, _, _ = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)

    def test_various_expert_counts(self):
        """Parity across different E and K values."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
            moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)

            sm_weights, sm_experts, _, _ = _softmax_topk_route(
                moe_block, gate, hidden, gate.weight, None
            )
            sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)

            sonic_weights_2d = sonic_scores.reshape(T, K)
            sonic_experts_2d = sonic_exp_idx.reshape(T, K)

            assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (
                f"Expert mismatch for E={E}, K={K}"
            )
            assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (
                f"Weight mismatch for E={E}, K={K}"
            )


# ============================================================================
# 2. Sigmoid routing parity
# ============================================================================


class TestSigmoidRoutingParity:
    """Verify scattermoe and sonicmoe sigmoid routing produce identical results."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def test_weights_match_with_groups(self):
        """Both implementations should produce identical weights with group selection."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
        )

        sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(
            hidden, moe_block
        )

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        assert sm_topk == K
        assert sm_E == E

        # Sort experts within each token to handle different topk orderings
        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)

        # Gather weights in sorted order for comparison
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_weights_match_no_groups(self):
        """Both implementations match without group selection (n_group=1)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
        )

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        # Sort for comparison (topk with sorted=False may differ in order)
        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_bias_on_block_parity(self):
        """minimax_m2 style: bias on block, not gate."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=1, bias_on_gate=False
        )

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_scaling_factor_parity(self):
        """routed_scaling_factor applied identically by both."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        moe_block.routed_scaling_factor = 2.5

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_no_renorm_parity(self):
        """norm_topk_prob=False produces same results in both."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        moe_block.norm_topk_prob = False

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)


# ============================================================================
# 3. Shared expert parity
# ============================================================================


class TestSharedExpertParity:
    """Verify both _compute_shared_expert implementations behave identically."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def _get_both_fns(self):
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert as scatter_compute,
        )
        from axolotl.integrations.kernels.sonicmoe.patch import (
            _compute_shared_expert as sonic_compute,
        )

        return scatter_compute, sonic_compute

    def test_shared_expert_singular(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(shared_expert=lambda x: out)
        hidden = torch.randn(4, 8)

        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))

    def test_shared_experts_plural(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(shared_experts=lambda x: out)
        hidden = torch.randn(4, 8)

        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))

    def test_shared_mlp(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(shared_mlp=lambda x: out)
        hidden = torch.randn(4, 8)

        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))

    def test_no_shared_expert(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        block = SimpleNamespace()
        hidden = torch.randn(4, 8)

        assert scatter_fn(block, hidden) is None
        assert sonic_fn(block, hidden) is None

    def test_shared_expert_gate_only_in_scattermoe(self):
        """ScatterMoE's _compute_shared_expert handles shared_expert_gate;
        SonicMoE's patch.py handles it externally in the forward function.

        This documents the known divergence: the scattermoe version applies
        sigmoid gating inline, while sonicmoe applies it in the forward.
        """
        scatter_fn, sonic_fn = self._get_both_fns()

        H = 8
        expert_out = torch.ones(4, H)
        gate_fn = lambda x: torch.zeros(4, H)  # noqa: E731  # sigmoid(0) = 0.5

        block = SimpleNamespace(
            shared_expert=lambda x: expert_out,
            shared_expert_gate=gate_fn,
        )
        hidden = torch.randn(4, H)

        scatter_result = scatter_fn(block, hidden)
        sonic_result = sonic_fn(block, hidden)

        # ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
        expected_gated = expert_out * 0.5
        assert torch.allclose(scatter_result, expected_gated, atol=1e-6)

        # SonicMoE does NOT apply the gate here (it does it in the forward)
        assert torch.equal(sonic_result, expert_out)


# ============================================================================
# 4. Route dispatcher parity
# ============================================================================


class TestRouteDispatcherParity:
    """Verify _route in scattermoe dispatches correctly and matches individual fns."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def test_route_dispatches_softmax(self):
        """_route should use softmax when no e_score_correction_bias."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _route,
            _softmax_topk_route,
        )

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()

        route_w, route_e, route_k, route_E = _route(
            moe_block, gate, hidden, gate.weight, None
        )
        direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.equal(route_w, direct_w)
        assert torch.equal(route_e, direct_e)
        assert route_k == direct_k
        assert route_E == direct_E

    def test_route_dispatches_sigmoid(self):
        """_route should use sigmoid when e_score_correction_bias is present."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _route,
            _sigmoid_topk_route,
        )

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )

        route_w, route_e, route_k, route_E = _route(
            moe_block, gate, hidden, gate.weight, None
        )
        direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.equal(route_w, direct_w)
        assert torch.equal(route_e, direct_e)
        assert route_k == direct_k
        assert route_E == direct_E
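The parity tests above all hinge on one layout invariant: reshaping SonicMoE's flattened [T*K] outputs to [T, K] must reproduce ScatterMoE's 2D layout, with row t holding token t's K entries. A minimal sketch with made-up values:

import torch

T, K = 3, 2
flat_scores = torch.arange(T * K, dtype=torch.float32)  # sonicmoe-style [T*K]
scores_2d = flat_scores.reshape(T, K)                   # scattermoe-style [T, K]
# The sort-then-gather in the tests only reorders within each row, so the
# (expert, weight) pairing per token is preserved.
assert scores_2d[1].tolist() == [2.0, 3.0]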
tests/integrations/test_scattermoe_autotune_telemetry.py (new file, +367)
@@ -0,0 +1,367 @@
|
||||
"""Tests for scattermoe autotune telemetry integration.
|
||||
|
||||
These tests use mocking to verify the collection and reporting logic
|
||||
without requiring Triton or CUDA.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Simulate the hash-suffixed module name that LocalLayerRepository creates.
|
||||
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
|
||||
|
||||
|
||||
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
|
||||
"""Create a mock triton.Config-like object."""
|
||||
return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
def _make_mock_kernel(cache=None):
|
||||
"""Create a mock autotuned kernel object with a ``.cache`` dict."""
|
||||
kernel = SimpleNamespace()
|
||||
kernel.cache = cache if cache is not None else {}
|
||||
return kernel
|
||||
|
||||
|
||||
def _make_mock_lora_ops(
|
||||
fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None
|
||||
):
|
||||
"""Build a mock ``lora_ops`` module with the four kernel attributes."""
|
||||
mod = SimpleNamespace(
|
||||
_scatter2scatter_lora=_make_mock_kernel(fwd_cache),
|
||||
_scatter2scatter_lora_dX=_make_mock_kernel(dx_cache),
|
||||
_group_bwd_lora=_make_mock_kernel(bwd_cache),
|
||||
_group_bwd_lora_fused=_make_mock_kernel(fused_cache),
|
||||
)
|
||||
return mod
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TestAutotuneCollector
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestAutotuneCollector:
|
||||
"""Test ``collect_autotune_configs`` with mocked kernel objects."""
|
||||
|
||||
def test_empty_cache_returns_empty_list(self):
|
||||
"""When no kernel has been autotuned yet, return ``[]``."""
|
||||
mock_lora_ops = _make_mock_lora_ops()
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
assert result == []
|
||||
|
||||
def test_populated_cache_returns_configs(self):
|
||||
"""When a cache entry exists, it appears in the output."""
|
||||
cfg = _make_mock_config(
|
||||
{"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4
|
||||
)
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 1
|
||||
entry = result[0]
|
||||
assert entry["kernel"] == "scatter2scatter_lora_fwd"
|
||||
assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024}
|
||||
assert entry["config"]["BLOCK_N"] == 128
|
||||
assert entry["config"]["BLOCK_K"] == 64
|
||||
assert entry["config"]["num_warps"] == 8
|
||||
assert entry["config"]["num_stages"] == 4
|
||||
|
||||
def test_multiple_kernels_and_keys(self):
|
||||
"""Multiple cache entries across kernels are all returned."""
|
||||
cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32})
|
||||
cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8)
|
||||
|
||||
mock_lora_ops = _make_mock_lora_ops(
|
||||
fwd_cache={(16, 256, 128): cfg_fwd},
|
||||
dx_cache={(16, 256, 128): cfg_dx},
|
||||
)
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 2
|
||||
names = {r["kernel"] for r in result}
|
||||
assert "scatter2scatter_lora_fwd" in names
|
||||
assert "scatter2scatter_lora_dX" in names
|
||||
|
||||
def test_extra_key_elements_stored(self):
|
||||
"""Dtype or other extra elements in the cache key are captured."""
|
||||
cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32})
|
||||
cache_key = (512, 1024, 256, "float16", "float16")
|
||||
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 1
|
||||
key = result[0]["key"]
|
||||
assert key["M"] == 512
|
||||
assert key["N"] == 1024
|
||||
assert key["K"] == 256
|
||||
assert key["_extra"] == ["float16", "float16"]
|
||||
|
||||
def test_no_module_in_sys_modules_returns_empty(self):
|
||||
"""If no lora_ops module is loaded, return ``[]``."""
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
# Don't inject anything — the real lora_ops isn't loaded either
|
||||
# (no triton on this machine), so _find_lora_ops_module returns None.
|
||||
result = collect_autotune_configs()
|
||||
assert result == []
|
||||
|
||||
def test_finds_module_under_hash_suffixed_name(self):
|
||||
"""Collector finds lora_ops regardless of the hash suffix."""
|
||||
cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128})
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})
|
||||
|
||||
# Use a different hash to prove it's not hardcoded.
|
||||
alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
|
||||
with patch.dict(sys.modules, {alt_name: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["config"]["BLOCK_N"] == 256
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TestAutotuneReportCallback
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestAutotuneReportCallback:
|
||||
"""Test the callback fires once and sends the correct event."""

    def test_reports_once_on_first_step(self):
        """Callback should call ``send_event`` exactly once."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1

        fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}]

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=fake_configs,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm

            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

            call_kwargs = mock_tm.send_event.call_args[1]
            assert call_kwargs["event_type"] == "scattermoe-autotune"
            assert call_kwargs["properties"]["kernel_count"] == 1

            # Second call should NOT send again.
            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

    def test_retries_until_step_5_then_gives_up(self):
        """If no configs found by step 5, stop retrying."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()

        with patch(
            "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
            return_value=[],
        ):
            for step in range(1, 7):
                mock_state = MagicMock()
                mock_state.global_step = step
                cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())

        assert cb._reported is True

    def test_reports_on_retry_when_data_arrives(self):
        """If step 1 has no data but step 2 does, report at step 2."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]

        call_count = 0

        def _collector():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return []
            return fake_configs

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                side_effect=_collector,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm

            # Step 1 — empty, no report
            s1 = MagicMock()
            s1.global_step = 1
            cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())
            assert mock_tm.send_event.call_count == 0

            # Step 2 — data arrives, report
            s2 = MagicMock()
            s2.global_step = 2
            cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())
            assert mock_tm.send_event.call_count == 1

    def test_includes_gpu_info(self):
        """Event properties should include GPU identification."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1

        fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
        fake_gpu = {
            "gpu_name": "NVIDIA H100",
            "gpu_compute_capability": "9.0",
            "gpu_memory_bytes": 85899345920,
        }
        fake_smem = {"smem_capacity_bytes": 233472}

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=fake_configs,
            ),
            patch(
                "axolotl.integrations.kernels.autotune_callback._get_gpu_info",
                return_value=fake_gpu,
            ),
            patch(
                "axolotl.integrations.kernels.autotune_callback._get_smem_capacity",
                return_value=fake_smem,
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = True
            mock_tm_cls.get_instance.return_value = mock_tm

            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            props = mock_tm.send_event.call_args[1]["properties"]
            assert props["gpu_name"] == "NVIDIA H100"
            assert props["gpu_compute_capability"] == "9.0"
            assert props["gpu_memory_bytes"] == 85899345920
            assert props["smem_capacity_bytes"] == 233472

    def test_skips_send_when_telemetry_disabled(self):
        """If telemetry is disabled, no event is sent."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )

        cb = AutotuneReportCallback()
        mock_state = MagicMock()
        mock_state.global_step = 1

        with (
            patch(
                "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
                return_value=[{"kernel": "fwd", "key": {}, "config": {}}],
            ),
            patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
        ):
            mock_tm = MagicMock()
            mock_tm.enabled = False
            mock_tm_cls.get_instance.return_value = mock_tm

            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
            assert mock_tm.send_event.call_count == 0
            # Should still mark as reported so we don't retry.
            assert cb._reported is True


# =========================================================================
# TestKernelsPluginCallbackRegistration
# =========================================================================


class TestKernelsPluginCallbackRegistration:
    """Test that ``KernelsPlugin`` registers the callback correctly."""

    def test_scattermoe_registers_callback(self):
        """When ``use_scattermoe=True``, plugin returns the callback."""
        from axolotl.integrations.kernels.autotune_callback import (
            AutotuneReportCallback,
        )
        from axolotl.integrations.kernels.plugin import KernelsPlugin

        plugin = KernelsPlugin()
        cfg = MagicMock()
        cfg.use_scattermoe = True
        model = MagicMock()

        callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
        assert len(callbacks) == 1
        assert isinstance(callbacks[0], AutotuneReportCallback)

    def test_no_scattermoe_no_callback(self):
        """When ``use_scattermoe=False``, plugin returns empty list."""
        from axolotl.integrations.kernels.plugin import KernelsPlugin

        plugin = KernelsPlugin()
        cfg = MagicMock()
        cfg.use_scattermoe = False
        model = MagicMock()

        callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
        assert callbacks == []
@@ -3,17 +3,19 @@
 # Licensed under the Apache License, Version 2.0

 """
-Unit tests for scattermoe-lora code-review fixes.
+Unit tests for scattermoe-lora.

 Tests cover:
 - KernelsArgs validator: disable_mlp_kernel
 - CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
 - ParallelExperts: scaling=0.0 not treated as falsy
 - single2scatter: non-aligned K/N dimensions
 - group_compileable: coeff=None accepted
 - HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
+- Routing strategy detection and sigmoid routing
+- Generic shared expert handling
 """

 from types import SimpleNamespace
 from unittest.mock import patch

 import pytest
@@ -321,3 +323,389 @@ class TestLayerReturnValues:
        assert "Router logits" not in docstring, (
            "Docstring should not mention 'Router logits' in Returns section"
        )


# ============================================================================
# 7. Routing strategy detection and sigmoid routing
# ============================================================================


def _make_softmax_gate(E=4, H=16, K=2):
    """Create a mock softmax-style gate (Qwen/OLMoE)."""
    return SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        num_experts=E,
        norm_topk_prob=True,
    )


def _make_sigmoid_gate_with_bias(E=16, H=16):
    """Create a mock sigmoid-style gate with e_score_correction_bias on gate."""
    return SimpleNamespace(
        weight=torch.randn(E, H),
        e_score_correction_bias=torch.zeros(E),
    )


def _make_sigmoid_moe_block(
    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
    """Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests."""
    if bias_on_gate:
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style: bias on block, not gate
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            e_score_correction_bias=torch.zeros(E),
        )
    return moe_block, T, H, E, K


def _skip_without_triton():
    pytest.importorskip("triton")


class TestSigmoidRoutingInScatterMoE:
    """Test _sigmoid_topk_route from layers.py."""

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        _skip_without_triton()

    def test_output_shapes(self):
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block()
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        assert top_k == K
        assert num_experts == E

    def test_weights_nonnegative(self):
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block()
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert (weights >= 0).all()

    def test_group_selection_restricts_experts(self):
        """With n_group=4, topk_group=1, experts should be from selected groups."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(
            E=16, K=2, n_group=4, topk_group=1
        )
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        _, expert_idx, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # Each token's experts should fall within a single group (size E//n_group=4)
        for t in range(T):
            experts_t = expert_idx[t]
            groups = experts_t // (E // moe_block.n_group)
            assert (groups == groups[0]).all()

    def test_scaling_factor_applied(self):
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights_1x, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        moe_block.routed_scaling_factor = 2.0
        weights_2x, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)

    def test_bias_on_gate(self):
        """e_score_correction_bias on gate is found."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)

    def test_bias_on_block(self):
        """e_score_correction_bias on moe_block (not gate) is found."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)

    def test_gate_lora_delta_applied(self):
        """Gate LoRA delta should affect routing logits."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights_no_lora, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # Large delta should change the results
        delta = torch.randn(E, H) * 10.0
        weights_with_lora, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, delta
        )

        assert not torch.equal(weights_no_lora, weights_with_lora)

    def test_no_bias_does_not_crash(self):
        """Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        T, H, E, K = 8, 16, 8, 2
        gate = SimpleNamespace(weight=torch.randn(E, H))
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        # Without bias, scores_for_choice == sigmoid(logits) — all positive
        assert (weights >= 0).all()

    def test_missing_topk_group_defaults_to_n_group(self):
        """When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        T, H, E, K, n_group = 8, 16, 16, 2, 4
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        # Intentionally omit topk_group
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        hidden = torch.randn(T, H)

        # Should not raise AttributeError; defaults topk_group to n_group
        weights, experts, top_k_out, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)
        assert experts.shape == (T, K)


class TestRoutingStrategyDetection:
    """Test that _route dispatches to the correct strategy."""

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        _skip_without_triton()

    def test_softmax_for_qwen_style(self):
        """Block without e_score_correction_bias should use softmax."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        gate = _make_softmax_gate(E=4, H=16, K=2)
        moe_block = SimpleNamespace(gate=gate)
        hidden = torch.randn(8, 16)

        weights, experts, top_k, num_experts = _route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (8, 2)
        assert experts.shape == (8, 2)
        assert top_k == 2
        assert num_experts == 4
        per_token_sums = weights.sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)

    def test_sigmoid_for_glm_style(self):
        """Block with e_score_correction_bias on gate should use sigmoid."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        assert (weights >= 0).all()

    def test_sigmoid_for_minimax_m2_style(self):
        """Block with e_score_correction_bias on block (not gate) should use sigmoid."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert (weights >= 0).all()


# ============================================================================
# 8. Generic shared expert handling
# ============================================================================


class TestGenericSharedExpert:
    """Test _compute_shared_expert from layers.py."""

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        _skip_without_triton()

    def test_shared_expert_singular(self):
        """shared_expert attribute (Qwen2MoE style)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        called = torch.randn(4, 8)
        moe_block = SimpleNamespace(
            shared_expert=lambda x: called,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert torch.equal(result, called)

    def test_shared_experts_plural(self):
        """shared_experts attribute (DeepSeek V3 style)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        called = torch.randn(4, 8)
        moe_block = SimpleNamespace(
            shared_experts=lambda x: called,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert torch.equal(result, called)

    def test_shared_mlp(self):
        """shared_mlp attribute (Hunyuan style)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        called = torch.randn(4, 8)
        moe_block = SimpleNamespace(
            shared_mlp=lambda x: called,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert torch.equal(result, called)

    def test_shared_expert_with_gate(self):
        """shared_expert + shared_expert_gate applies sigmoid gating."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        H = 8
        expert_out = torch.ones(4, H)
        gate_fn = lambda x: torch.zeros(4, H)  # noqa: E731

        moe_block = SimpleNamespace(
            shared_expert=lambda x: expert_out,
            shared_expert_gate=gate_fn,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, H))
        expected = expert_out * 0.5  # sigmoid(0) = 0.5
        assert torch.allclose(result, expected, atol=1e-6)

    def test_no_shared_expert(self):
        """No shared expert attributes returns None."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        moe_block = SimpleNamespace()
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert result is None