consolidate behaviour of routing in scattermoe kernels (#3475)

* consolidate behaviour of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
Wing Lian
2026-03-16 23:47:40 -04:00
committed by GitHub
parent 830e9f7eaf
commit 8f3fb517b3
8 changed files with 1988 additions and 35 deletions
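Background for the two new modules below: Triton's ``@triton.autotune`` decorator benchmarks its candidate configs the first time a kernel is launched for a new ``key`` value and caches the winning config on the kernel object's ``.cache`` dict, which is what the new collector reads. A minimal standalone sketch of that mechanism, not part of this commit (the kernel, tensor sizes, and configs are illustrative, and it assumes a CUDA device):

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2),
        triton.Config({"BLOCK": 128}, num_warps=4),
    ],
    key=["N"],  # a new value of N triggers a fresh autotune run
)
@triton.jit
def _double_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    # Trivial elementwise kernel, only here to exercise the autotuner.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)


x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
_double_kernel[grid](x, out, x.numel())

# After the first launch the winning config is cached, keyed by the declared
# `key` values plus dtype signature elements.
for key, cfg in _double_kernel.cache.items():
    print(key, cfg.kwargs, cfg.num_warps, cfg.num_stages)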


@@ -0,0 +1,120 @@
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""
import logging
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
LOG = logging.getLogger(__name__)
# Give up looking for autotune data after this many training steps.
_MAX_POLL_STEP = 5
def _get_gpu_info() -> dict:
"""Return basic GPU identification for the current device."""
if not torch.cuda.is_available():
return {}
try:
idx = torch.cuda.current_device()
props = torch.cuda.get_device_properties(idx)
return {
"gpu_name": props.name,
"gpu_compute_capability": f"{props.major}.{props.minor}",
"gpu_memory_bytes": props.total_memory,
}
except Exception: # pylint: disable=broad-exception-caught
return {}
def _get_smem_capacity() -> dict:
"""Return shared memory capacity from the runtime lora_ops module."""
try:
from axolotl.integrations.kernels.autotune_collector import (
_find_lora_ops_module,
)
lora_ops = _find_lora_ops_module()
if lora_ops is None:
return {}
fn = getattr(lora_ops, "_get_smem_capacity", None)
if fn is None:
return {}
return {"smem_capacity_bytes": fn()}
except Exception: # pylint: disable=broad-exception-caught
return {}
class AutotuneReportCallback(TrainerCallback):
"""Reports Triton kernel autotune selections via telemetry.
Fires **once** after the first training step completes (step 1), at
which point the forward and backward passes have both run and the
autotuned kernels have populated their caches. If for some reason
the caches are still empty (e.g. the kernel was never invoked), the
callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
then stops polling.
After reporting (or giving up) every subsequent ``on_step_end``
call short-circuits on the ``_reported`` flag — zero hot-path cost.
"""
def __init__(self):
self._reported = False
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._reported:
return
# Lazy import — Triton / scattermoe kernels may not be installed.
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
configs = collect_autotune_configs()
if not configs:
if state.global_step >= _MAX_POLL_STEP:
LOG.debug(
"No autotune data found after %d steps; giving up.",
state.global_step,
)
self._reported = True
return
self._reported = True
from axolotl.telemetry.manager import TelemetryManager
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return
properties = {
"kernel_count": len(configs),
"kernels": configs,
}
properties.update(_get_gpu_info())
properties.update(_get_smem_capacity())
telemetry_manager.send_event(
event_type="scattermoe-autotune",
properties=properties,
)
LOG.info(
"Reported %d scattermoe kernel autotune config(s) to telemetry.",
len(configs),
)
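
The callback follows the standard ``transformers.TrainerCallback`` protocol, so its polling and short-circuit behaviour can be exercised without a full training run. A rough sketch, assuming axolotl and its kernels integration are importable; the loop and output path are illustrative, not from this commit:

from transformers import TrainerControl, TrainerState, TrainingArguments

from axolotl.integrations.kernels.autotune_callback import AutotuneReportCallback

cb = AutotuneReportCallback()
args = TrainingArguments(output_dir="/tmp/autotune-demo")  # illustrative path
state = TrainerState()
control = TrainerControl()

for step in range(1, 8):
    state.global_step = step
    cb.on_step_end(args, state, control)
    if cb._reported:
        # Either configs were sent or _MAX_POLL_STEP was exceeded;
        # every later call now short-circuits on this flag.
        break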


@@ -0,0 +1,114 @@
"""Collect Triton autotune results from scattermoe-lora kernels.
This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
decorated kernel objects and returns structured dicts describing the selected
configurations. It has **no** telemetry dependency — callers decide what to
do with the data.
"""
import logging
import sys
from types import ModuleType
from typing import Any
LOG = logging.getLogger(__name__)
# (human-readable name, attribute on the lora_ops module)
_KERNEL_REGISTRY: list[tuple[str, str]] = [
("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
("group_bwd_lora", "_group_bwd_lora"),
("group_bwd_lora_fused", "_group_bwd_lora_fused"),
]
# The autotune key declared on every kernel: key=["M", "N", "K"]
_KEY_NAMES: list[str] = ["M", "N", "K"]
def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
"""Turn the autotune cache key tuple into a labelled dict.
Triton builds the cache key from the values of the declared ``key``
args (``M``, ``N``, ``K``) followed by dtype signature elements.
We label the first three and store the rest under ``_extra``.
"""
result: dict[str, Any] = {}
for i, name in enumerate(_KEY_NAMES):
if i < len(key_tuple):
result[name] = key_tuple[i]
if len(key_tuple) > len(_KEY_NAMES):
result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
return result
def _find_lora_ops_module() -> ModuleType | None:
"""Locate the *runtime* ``lora_ops`` module in ``sys.modules``.
The HF ``kernels`` package loads ``scattermoe_lora`` via
``import_from_path`` which registers it in ``sys.modules`` under a
hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
import (``from axolotl.integrations.kernels...``) would create a
*separate* module instance whose kernel objects have empty
``.cache`` dicts because autotuning ran on the runtime copy.
We search ``sys.modules`` for any module whose name contains
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
attribute — that is the runtime copy with populated caches.
"""
for name, module in sys.modules.items():
if (
module is not None
and "lora_ops" in name
and hasattr(module, "_scatter2scatter_lora")
):
return module
return None
def collect_autotune_configs() -> list[dict[str, Any]]:
"""Read autotune caches from the four scattermoe-lora kernels.
Returns a (possibly empty) list of dicts, each containing:
* ``kernel`` human-readable kernel name
* ``key`` dict with the ``M``/``N``/``K`` problem dimensions
* ``config`` dict with the selected tile sizes, ``num_warps``,
and ``num_stages``
Returns ``[]`` if the kernel module cannot be found or if no
autotune cache entries exist yet.
"""
lora_ops = _find_lora_ops_module()
if lora_ops is None:
LOG.debug(
"lora_ops module not found in sys.modules; skipping autotune collection"
)
return []
results: list[dict[str, Any]] = []
for friendly_name, attr_name in _KERNEL_REGISTRY:
kernel_fn = getattr(lora_ops, attr_name, None)
if kernel_fn is None:
continue
cache = getattr(kernel_fn, "cache", None)
if not cache:
continue
for key_tuple, config in cache.items():
config_dict = dict(config.kwargs)
config_dict["num_warps"] = config.num_warps
config_dict["num_stages"] = config.num_stages
if getattr(config, "num_ctas", None) is not None:
config_dict["num_ctas"] = config.num_ctas
results.append(
{
"kernel": friendly_name,
"key": _parse_key_tuple(key_tuple),
"config": config_dict,
}
)
return results
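
For a sense of the payload shape: with a single tuned problem size per kernel, ``collect_autotune_configs()`` returns a list like the one below. The tile-size keyword names and all numbers are made up for illustration; the real kwargs come from each kernel's ``triton.Config`` definitions.

# Illustrative only: tile names and values are invented, not measured.
example_configs = [
    {
        "kernel": "scatter2scatter_lora_fwd",
        "key": {"M": 16384, "N": 768, "K": 2048, "_extra": ["torch.bfloat16"]},
        "config": {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "num_warps": 4, "num_stages": 3},
    },
    {
        "kernel": "group_bwd_lora",
        "key": {"M": 16384, "N": 768, "K": 2048, "_extra": ["torch.bfloat16"]},
        "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 8, "num_stages": 2},
    },
]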


@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
return base_experts, gup_lora, down_lora
# =============================================================================
# Routing helpers
# =============================================================================
def _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return routing_weights, selected_experts, top_k, num_experts
def _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).
Supports:
- ``e_score_correction_bias`` on gate or moe_block
- Group-based expert selection when ``n_group > 1``
- ``routed_scaling_factor`` applied to final weights
- Final weights gathered from original sigmoid probs (not bias-corrected)
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states.float(), gate_weight.float())
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = router_logits.sigmoid() # [T, E]
top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is not None:
scores_for_choice = router_probs + e_score_correction_bias
else:
scores_for_choice = router_probs
# Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1)
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, num_experts // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [T, n_group]
topk_group = getattr(moe_block, "topk_group", n_group)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(-1, n_group, num_experts // n_group)
.reshape(-1, num_experts)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_indices, top_k, num_experts
def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
"""Dispatch to the correct routing strategy based on block attributes.
Detects sigmoid routing by the presence of ``e_score_correction_bias``
on either the gate or the moe_block.
"""
has_sigmoid = (
getattr(base_gate, "e_score_correction_bias", None) is not None
or getattr(moe_block, "e_score_correction_bias", None) is not None
)
if has_sigmoid:
return _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
return _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
# =============================================================================
# Shared expert helpers
# =============================================================================
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
peft wraps individual linear layers inside the shared expert with
standard LoRA — calling forward() handles this transparently.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is None:
return None
shared_expert_output = shared_expert(hidden_states_flat)
# Optional sigmoid gate (Qwen2MoE pattern).
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
if shared_expert_gate is not None:
shared_expert_output = (
F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
)
return shared_expert_output
# =============================================================================
# Layer classes
# =============================================================================
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):
class HFScatterMoEGatedMLP(nn.Module):
"""
- ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
+ ScatterMoE-accelerated forward pass for HF MoEs.
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
- method replaces the original ``OlmoeSparseMoeBlock.forward``.
+ method replaces the original SparseMoeBlock.forward.
- Supports both full-parameter training and LoRA fine-tuning:
+ Supports:
- * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
- * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
-   adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+ * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
+ * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+ * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
+ * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
+   extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
-   - self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
+   - self.shared_expert(s): Optional shared expert
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
- # Shared Expert (if present, e.g. Qwen2MoE)
+ # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
# ====================================================================
- # peft wraps individual linear layers inside shared_expert with
- # standard LoRA — calling forward() handles this transparently.
- if hasattr(self, "shared_expert") and self.shared_expert is not None:
-     shared_expert_output = self.shared_expert(hidden_states_flat)
-     # shared_expert_gate may also be peft-wrapped (standard LoRA
-     # on nn.Linear), its forward() applies LoRA automatically.
-     shared_expert_gate_output = F.sigmoid(
-         self.shared_expert_gate(hidden_states_flat)
-     )
-     shared_expert_output = shared_expert_output * shared_expert_gate_output
- else:
-     shared_expert_output = None
+ shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
- router_logits = F.linear(hidden_states_flat, gate_weight)
- if gate_lora_delta is not None:
-     router_logits = router_logits + F.linear(
-         hidden_states_flat, gate_lora_delta
-     )
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- top_k = base_gate.top_k
- num_experts = base_gate.num_experts
- routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
- if base_gate.norm_topk_prob:
-     routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights, selected_experts, top_k, num_experts = _route(
+     self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
+ )
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
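
The two routing paths consolidated above reduce to a handful of tensor ops. A self-contained sketch of the same selection logic on toy tensors; the sizes, ``n_group``, the zero bias, and the scaling factor are made-up stand-ins for the module attributes the real code reads:

import torch
import torch.nn.functional as F

T, E, K = 4, 8, 2            # tokens, experts, experts-per-token
logits = torch.randn(T, E)

# Softmax→topk (OLMoE / Qwen / Mixtral style), with renormalization.
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
weights, experts = torch.topk(probs, K, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)

# Sigmoid→topk with group-limited selection (DeepSeek V3 / GLM style).
n_group, topk_group = 4, 2
sig = logits.sigmoid()
bias = torch.zeros(E)                         # e_score_correction_bias stand-in
scores = sig + bias
# Score each group by the sum of its top-2 experts, keep only the best groups.
group_scores = scores.view(T, n_group, E // n_group).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(T, E)
masked = scores.masked_fill(~score_mask.bool(), 0.0)
topk_idx = torch.topk(masked, k=K, dim=-1, sorted=False)[1]
# Final weights come from the raw sigmoid scores, then renormalize and scale.
topk_w = sig.gather(1, topk_idx)
topk_w = topk_w / (topk_w.sum(dim=-1, keepdim=True) + 1e-20)
topk_w = topk_w * 2.5                         # routed_scaling_factor stand-in
print(experts, weights, topk_idx, topk_w, sep="\n")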


@@ -110,6 +110,16 @@ class KernelsPlugin(BasePlugin):
}
)
def add_callbacks_pre_trainer(self, cfg, model):
callbacks = []
if cfg.use_scattermoe:
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
callbacks.append(AutotuneReportCallback())
return callbacks
def _kernelize_model(self, model_type: str):
from kernels import replace_kernel_forward_from_hub