consolidate behavior of routing in scattermoe kernels (#3475)
* consolidate behavior of routing in scattermoe kernels
* collect telemetry on best chosen autotuned kernel
* properly collect data
* Fix property name and get smem too
* handle issues raised by coderabbit
* add tests for parity before refactoring
src/axolotl/integrations/kernels/autotune_callback.py (new file, 120 lines)
@@ -0,0 +1,120 @@
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""

import logging

import torch
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

LOG = logging.getLogger(__name__)

# Give up looking for autotune data after this many training steps.
_MAX_POLL_STEP = 5


def _get_gpu_info() -> dict:
    """Return basic GPU identification for the current device."""
    if not torch.cuda.is_available():
        return {}
    try:
        idx = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(idx)
        return {
            "gpu_name": props.name,
            "gpu_compute_capability": f"{props.major}.{props.minor}",
            "gpu_memory_bytes": props.total_memory,
        }
    except Exception:  # pylint: disable=broad-exception-caught
        return {}


def _get_smem_capacity() -> dict:
    """Return shared memory capacity from the runtime lora_ops module."""
    try:
        from axolotl.integrations.kernels.autotune_collector import (
            _find_lora_ops_module,
        )

        lora_ops = _find_lora_ops_module()
        if lora_ops is None:
            return {}
        fn = getattr(lora_ops, "_get_smem_capacity", None)
        if fn is None:
            return {}
        return {"smem_capacity_bytes": fn()}
    except Exception:  # pylint: disable=broad-exception-caught
        return {}


class AutotuneReportCallback(TrainerCallback):
    """Reports Triton kernel autotune selections via telemetry.

    Fires **once** after the first training step completes (step 1), at
    which point the forward and backward passes have both run and the
    autotuned kernels have populated their caches. If for some reason
    the caches are still empty (e.g. the kernel was never invoked), the
    callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
    then stops polling.

    After reporting (or giving up) every subsequent ``on_step_end``
    call short-circuits on the ``_reported`` flag — zero hot-path cost.
    """

    def __init__(self):
        self._reported = False

    # pylint: disable=unused-argument
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if self._reported:
            return

        # Lazy import — Triton / scattermoe kernels may not be installed.
        from axolotl.integrations.kernels.autotune_collector import (
            collect_autotune_configs,
        )

        configs = collect_autotune_configs()

        if not configs:
            if state.global_step >= _MAX_POLL_STEP:
                LOG.debug(
                    "No autotune data found after %d steps; giving up.",
                    state.global_step,
                )
                self._reported = True
            return

        self._reported = True

        from axolotl.telemetry.manager import TelemetryManager

        telemetry_manager = TelemetryManager.get_instance()
        if not telemetry_manager.enabled:
            return

        properties = {
            "kernel_count": len(configs),
            "kernels": configs,
        }
        properties.update(_get_gpu_info())
        properties.update(_get_smem_capacity())

        telemetry_manager.send_event(
            event_type="scattermoe-autotune",
            properties=properties,
        )

        LOG.info(
            "Reported %d scattermoe kernel autotune config(s) to telemetry.",
            len(configs),
        )
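For orientation, here is a sketch of the payload shape this callback sends. The top-level keys come straight from the code above; the tile-size names inside "config" (BLOCK_M, BLOCK_N) and every numeric value are illustrative assumptions, not output from a real run.

# Illustrative payload only; real kernels may expose different config fields.
example_properties = {
    "kernel_count": 1,
    "kernels": [
        {
            "kernel": "scatter2scatter_lora_fwd",
            "key": {"M": 4096, "N": 2048, "K": 1024, "_extra": ["torch.bfloat16"]},
            "config": {"BLOCK_M": 128, "BLOCK_N": 64, "num_warps": 4, "num_stages": 3},
        }
    ],
    "gpu_name": "NVIDIA A100-SXM4-80GB",
    "gpu_compute_capability": "8.0",
    "gpu_memory_bytes": 85899345920,
    "smem_capacity_bytes": 167936,
}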
src/axolotl/integrations/kernels/autotune_collector.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Collect Triton autotune results from scattermoe-lora kernels.

This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
decorated kernel objects and returns structured dicts describing the selected
configurations. It has **no** telemetry dependency — callers decide what to
do with the data.
"""

import logging
import sys
from types import ModuleType
from typing import Any

LOG = logging.getLogger(__name__)

# (human-readable name, attribute on the lora_ops module)
_KERNEL_REGISTRY: list[tuple[str, str]] = [
    ("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
    ("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
    ("group_bwd_lora", "_group_bwd_lora"),
    ("group_bwd_lora_fused", "_group_bwd_lora_fused"),
]

# The autotune key declared on every kernel: key=["M", "N", "K"]
_KEY_NAMES: list[str] = ["M", "N", "K"]


def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
    """Turn the autotune cache key tuple into a labelled dict.

    Triton builds the cache key from the values of the declared ``key``
    args (``M``, ``N``, ``K``) followed by dtype signature elements.
    We label the first three and store the rest under ``_extra``.
    """
    result: dict[str, Any] = {}
    for i, name in enumerate(_KEY_NAMES):
        if i < len(key_tuple):
            result[name] = key_tuple[i]
    if len(key_tuple) > len(_KEY_NAMES):
        result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
    return result


def _find_lora_ops_module() -> ModuleType | None:
    """Locate the *runtime* ``lora_ops`` module in ``sys.modules``.

    The HF ``kernels`` package loads ``scattermoe_lora`` via
    ``import_from_path`` which registers it in ``sys.modules`` under a
    hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
    import (``from axolotl.integrations.kernels...``) would create a
    *separate* module instance whose kernel objects have empty
    ``.cache`` dicts because autotuning ran on the runtime copy.

    We search ``sys.modules`` for any module whose name contains
    ``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
    attribute — that is the runtime copy with populated caches.
    """
    for name, module in sys.modules.items():
        if (
            module is not None
            and "lora_ops" in name
            and hasattr(module, "_scatter2scatter_lora")
        ):
            return module
    return None


def collect_autotune_configs() -> list[dict[str, Any]]:
    """Read autotune caches from the four scattermoe-lora kernels.

    Returns a (possibly empty) list of dicts, each containing:

    * ``kernel`` – human-readable kernel name
    * ``key`` – dict with the ``M``/``N``/``K`` problem dimensions
    * ``config`` – dict with the selected tile sizes, ``num_warps``,
      and ``num_stages``

    Returns ``[]`` if the kernel module cannot be found or if no
    autotune cache entries exist yet.
    """
    lora_ops = _find_lora_ops_module()
    if lora_ops is None:
        LOG.debug(
            "lora_ops module not found in sys.modules; skipping autotune collection"
        )
        return []

    results: list[dict[str, Any]] = []

    for friendly_name, attr_name in _KERNEL_REGISTRY:
        kernel_fn = getattr(lora_ops, attr_name, None)
        if kernel_fn is None:
            continue

        cache = getattr(kernel_fn, "cache", None)
        if not cache:
            continue

        for key_tuple, config in cache.items():
            config_dict = dict(config.kwargs)
            config_dict["num_warps"] = config.num_warps
            config_dict["num_stages"] = config.num_stages
            if getattr(config, "num_ctas", None) is not None:
                config_dict["num_ctas"] = config.num_ctas

            results.append(
                {
                    "kernel": friendly_name,
                    "key": _parse_key_tuple(key_tuple),
                    "config": config_dict,
                }
            )

    return results
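A minimal usage sketch (not part of the diff): once at least one training step has exercised the scattermoe-lora kernels, the collector can be called directly, for example from a debugger or notebook.

import json

from axolotl.integrations.kernels.autotune_collector import collect_autotune_configs

# Each entry carries "kernel", "key" (the M/N/K problem shape), and "config"
# (the selected tile sizes plus num_warps / num_stages).
for entry in collect_autotune_configs():
    print(json.dumps(entry, indent=2, default=str))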
@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
    return base_experts, gup_lora, down_lora


# =============================================================================
# Routing helpers
# =============================================================================


def _softmax_topk_route(
    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
    """Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).

    Returns:
        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
    """
    router_logits = F.linear(hidden_states, gate_weight)
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)

    top_k = base_gate.top_k
    num_experts = base_gate.num_experts
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    if getattr(base_gate, "norm_topk_prob", True):
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    return routing_weights, selected_experts, top_k, num_experts
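A self-contained sketch of the same softmax-then-topk-then-renormalize sequence on toy shapes (random weights, made-up sizes); it mirrors the helper above and is useful as a reference when checking parity.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, H, E, K = 4, 8, 6, 2          # tokens, hidden size, experts, top_k
hidden = torch.randn(T, H)
gate_weight = torch.randn(E, H)  # router weight, [num_experts, hidden]

logits = F.linear(hidden, gate_weight)                      # [T, E]
weights = F.softmax(logits, dim=-1, dtype=torch.float32)    # [T, E]
weights, selected_experts = torch.topk(weights, K, dim=-1)  # both [T, K]
weights = weights / weights.sum(dim=-1, keepdim=True)       # norm_topk_prob=True

assert torch.allclose(weights.sum(dim=-1), torch.ones(T))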
def _sigmoid_topk_route(
    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
    """Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).

    Supports:
    - ``e_score_correction_bias`` on gate or moe_block
    - Group-based expert selection when ``n_group > 1``
    - ``routed_scaling_factor`` applied to final weights
    - Final weights gathered from original sigmoid probs (not bias-corrected)

    Returns:
        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
    """
    router_logits = F.linear(hidden_states.float(), gate_weight.float())
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = router_logits.sigmoid()  # [T, E]

    top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
    num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is not None:
        scores_for_choice = router_probs + e_score_correction_bias
    else:
        scores_for_choice = router_probs

    # Group-based selection: pick top groups, mask the rest
    n_group = getattr(moe_block, "n_group", 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, num_experts // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        topk_group = getattr(moe_block, "topk_group", n_group)
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, n_group, num_experts // n_group)
            .reshape(-1, num_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    # Final topk from (possibly masked) scores
    topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]

    # Gather weights from original sigmoid scores (not bias-corrected)
    topk_weights = router_probs.gather(1, topk_indices)

    # Optional renormalization + scaling
    if getattr(moe_block, "norm_topk_prob", True):
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    return topk_weights, topk_indices, top_k, num_experts
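The select-with-bias versus weight-without-bias split above is the subtle part; this toy sketch (made-up shapes, no expert grouping) mirrors that behavior in isolation.

import torch

torch.manual_seed(0)
T, E, K = 3, 8, 2
router_probs = torch.rand(T, E)          # stands in for sigmoid(router_logits)
correction_bias = torch.randn(E) * 0.1   # stands in for e_score_correction_bias

# Experts are *chosen* with the bias applied...
scores_for_choice = router_probs + correction_bias
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]

# ...but the mixing *weights* come from the un-biased sigmoid probs.
topk_weights = router_probs.gather(1, topk_indices)
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
topk_weights = topk_weights * 2.5        # stands in for routed_scaling_factor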
def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
    """Dispatch to the correct routing strategy based on block attributes.

    Detects sigmoid routing by the presence of ``e_score_correction_bias``
    on either the gate or the moe_block.
    """
    has_sigmoid = (
        getattr(base_gate, "e_score_correction_bias", None) is not None
        or getattr(moe_block, "e_score_correction_bias", None) is not None
    )
    if has_sigmoid:
        return _sigmoid_topk_route(
            moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
        )
    return _softmax_topk_route(
        moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
    )


# =============================================================================
# Shared expert helpers
# =============================================================================


def _compute_shared_expert(moe_block, hidden_states_flat):
    """Compute shared expert output if the block has one.

    Handles singular (qwen2_moe: ``shared_expert``), plural
    (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
    (hunyuan_v1_moe: ``shared_mlp``) attribute names.

    peft wraps individual linear layers inside the shared expert with
    standard LoRA — calling forward() handles this transparently.
    """
    shared_expert = (
        getattr(moe_block, "shared_expert", None)
        or getattr(moe_block, "shared_experts", None)
        or getattr(moe_block, "shared_mlp", None)
    )
    if shared_expert is None:
        return None

    shared_expert_output = shared_expert(hidden_states_flat)

    # Optional sigmoid gate (Qwen2MoE pattern).
    # shared_expert_gate may also be peft-wrapped (standard LoRA
    # on nn.Linear), its forward() applies LoRA automatically.
    shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
    if shared_expert_gate is not None:
        shared_expert_output = (
            F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
        )

    return shared_expert_output


# =============================================================================
# Layer classes
# =============================================================================
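The commit message mentions adding parity tests before refactoring; below is a hedged sketch of such a check for the softmax path. It takes _route as an argument because the module that defines it is not named in this diff, so no import path is assumed.

from types import SimpleNamespace

import torch
import torch.nn.functional as F


def reference_softmax_route(hidden, gate_weight, top_k):
    """Inline restatement of the pre-refactor routing, for comparison."""
    logits = F.linear(hidden, gate_weight)
    weights = F.softmax(logits, dim=-1, dtype=torch.float32)
    weights, selected = torch.topk(weights, top_k, dim=-1)
    return weights / weights.sum(dim=-1, keepdim=True), selected


def check_softmax_parity(_route):
    torch.manual_seed(0)
    hidden = torch.randn(5, 16)
    gate_weight = torch.randn(4, 16)
    base_gate = SimpleNamespace(top_k=2, num_experts=4, norm_topk_prob=True)
    moe_block = SimpleNamespace()  # no e_score_correction_bias -> softmax path

    got_w, got_idx, top_k, num_experts = _route(
        moe_block, base_gate, hidden, gate_weight, gate_lora_delta=None
    )
    ref_w, ref_idx = reference_softmax_route(hidden, gate_weight, base_gate.top_k)
    assert torch.equal(got_idx, ref_idx)
    assert torch.allclose(got_w, ref_w)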
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):
 class HFScatterMoEGatedMLP(nn.Module):
     """
-    ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
+    ScatterMoE-accelerated forward pass for HF MoEs.

     Used as a kernel layer via the HF ``kernels`` library. The ``forward``
-    method replaces the original ``OlmoeSparseMoeBlock.forward``.
+    method replaces the original SparseMoeBlock.forward.

-    Supports both full-parameter training and LoRA fine-tuning:
+    Supports:

-    * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
-    * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
-      adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+    * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
+    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+    * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
+    * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
+      extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
     """

     @staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
             self: The MoeSparseMoeBlock module containing:
                 - self.gate: Router (or peft ParamWrapper wrapping it)
                 - self.experts: Experts module (or peft ParamWrapper chain)
-                - self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
+                - self.shared_expert(s): Optional shared expert
                 - self.shared_expert_gate: Optional shared expert gate
             layer_input: Input tensor [batch_size, seq_len, hidden_size]
@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
         hidden_states_flat = layer_input.view(-1, hidden_dim)

         # ====================================================================
-        # Shared Expert (if present, e.g. Qwen2MoE)
+        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
         # ====================================================================
-        # peft wraps individual linear layers inside shared_expert with
-        # standard LoRA — calling forward() handles this transparently.
-        if hasattr(self, "shared_expert") and self.shared_expert is not None:
-            shared_expert_output = self.shared_expert(hidden_states_flat)
-            # shared_expert_gate may also be peft-wrapped (standard LoRA
-            # on nn.Linear), its forward() applies LoRA automatically.
-            shared_expert_gate_output = F.sigmoid(
-                self.shared_expert_gate(hidden_states_flat)
-            )
-            shared_expert_output = shared_expert_output * shared_expert_gate_output
-        else:
-            shared_expert_output = None
+        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

         # ====================================================================
         # Router Computation (with optional gate LoRA)
         # ====================================================================
         base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
-        router_logits = F.linear(hidden_states_flat, gate_weight)
-        if gate_lora_delta is not None:
-            router_logits = router_logits + F.linear(
-                hidden_states_flat, gate_lora_delta
-            )
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-
-        top_k = base_gate.top_k
-        num_experts = base_gate.num_experts
-        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-
-        if base_gate.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights, selected_experts, top_k, num_experts = _route(
+            self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
+        )
         routing_weights = routing_weights.to(hidden_states_flat.dtype)

         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
@@ -110,6 +110,16 @@ class KernelsPlugin(BasePlugin):
             }
         )

+    def add_callbacks_pre_trainer(self, cfg, model):
+        callbacks = []
+        if cfg.use_scattermoe:
+            from axolotl.integrations.kernels.autotune_callback import (
+                AutotuneReportCallback,
+            )
+
+            callbacks.append(AutotuneReportCallback())
+        return callbacks
+
     def _kernelize_model(self, model_type: str):
         from kernels import replace_kernel_forward_from_hub
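A small sketch of how the new hook is exercised; plugin construction is not shown in this diff, so the snippet takes an already-built KernelsPlugin instance and a minimal stand-in config.

from types import SimpleNamespace


def demo_callback_wiring(plugin):
    """plugin: an instantiated KernelsPlugin (construction not shown in this diff)."""
    cfg = SimpleNamespace(use_scattermoe=True)
    callbacks = plugin.add_callbacks_pre_trainer(cfg, model=None)
    # With use_scattermoe enabled, the list should contain an AutotuneReportCallback.
    print([type(cb).__name__ for cb in callbacks])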