update for transformers v5 expert parameters and compose with MoE kernels

This commit is contained in:
Wing Lian
2026-03-22 11:52:34 -04:00
parent 4009a2ba5f
commit 5acb1b0ade
5 changed files with 650 additions and 162 deletions

View File

@@ -1,11 +1,19 @@
"""Architecture-specific adapters for aux-loss-free MoE routing.
Each adapter discovers MoE layers for a model family and patches only the
router/gate to inject per-expert bias into expert selection while keeping
mixture weights from unbiased logits. Expert dispatch is left untouched so
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
import torch.nn.functional as F
from torch import nn
from axolotl.utils.logging import get_logger
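The invariant stated in the module docstring (biased scores pick the experts, unbiased scores weight them) reduces to a few tensor ops. A minimal standalone sketch for reference; the function name is illustrative and not part of this module:

import torch

def aux_free_select(logits: torch.Tensor, bias: torch.Tensor, top_k: int):
    # logits: [tokens, num_experts]; bias: [num_experts], broadcast per token
    biased = logits + bias                      # biased scores drive selection
    _, idx = torch.topk(biased, top_k, dim=-1)  # [tokens, top_k]
    unbiased = torch.gather(logits, -1, idx)    # weights come from raw logits
    return idx, torch.softmax(unbiased.float(), dim=-1)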
@@ -23,9 +31,10 @@ class LayerHandle:
class BaseMoEAdapter:
"""Base adapter that discovers MoE layers and wraps their forward.
"""Base adapter that discovers MoE layers and patches their routing.
Concrete adapters implement discovery, attribute extraction, and
architecture-specific router patching.
"""
family: str = "generic"
@@ -33,158 +42,191 @@ class BaseMoEAdapter:
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
return False
def find_moe_layers(
self, model: nn.Module
) -> Iterable[nn.Module]: # pragma: no cover
return []
def get_top_k(self, moe_layer: nn.Module) -> int:
"""Resolve top_k from the MoE layer, checking common attribute paths."""
for attr_path in [
("top_k",),
("num_experts_per_tok",),
("gate", "top_k"),
("router", "top_k"),
]:
obj: object = moe_layer
for attr in attr_path:
obj = getattr(obj, attr, None)
if obj is None:
break
if isinstance(obj, int):
return obj
return 2
def get_num_experts(self, moe_layer: nn.Module) -> int:
"""Resolve num_experts from the MoE layer, checking common attribute paths."""
for attr_path in [
("num_experts",),
("num_local_experts",),
("gate", "num_experts"),
("router", "num_experts"),
("experts", "num_experts"),
]:
obj: object = moe_layer
for attr in attr_path:
obj = getattr(obj, attr, None)
if obj is None:
break
if isinstance(obj, int):
return obj
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
# Best-effort: zero router aux loss coef if present
if hasattr(model_or_layer, "router_aux_loss_coef"):
try:
model_or_layer.router_aux_loss_coef = 0.0
except Exception: # pragma: no cover - non-critical
pass
def _register_aux_buffers(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
device = next(moe_layer.parameters(), torch.tensor(0)).device
if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
moe_layer.register_buffer(
"_afb_bias", torch.zeros(handle.num_experts, device=device)
)
if not hasattr(moe_layer, "_afb_counts"):
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
moe_layer.register_buffer(
"_afb_counts", torch.zeros(handle.num_experts, device=device)
)
if not hasattr(moe_layer, "_afb_ema"):
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
moe_layer.register_buffer(
"_afb_ema", torch.zeros(handle.num_experts, device=device)
)
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
shim.register_layer_buffers(handle.layer_idx, moe_layer)
def prepare(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
"""Attach per-layer buffers. Subclasses override to also patch routing."""
self._register_aux_buffers(moe_layer, handle, shim)
def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
"""Return True when a kernel backend (SonicMoE / ScatterMoE) has
already replaced the block forward, meaning the routing is handled
inside the kernel forward and we should NOT patch the router."""
cls = type(moe_layer)
# SonicMoE stores the original forward when it patches a class.
if hasattr(cls, "_original_forward"):
return True
# ScatterMoE replaces via kernels library; check for the marker.
if hasattr(cls, "_kernel_forward"):
return True
return False
class MixtralAdapter(BaseMoEAdapter):
"""Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
routing so that biased logits drive expert *selection* while unbiased
softmax scores drive mixture *weights*.
Works with transformers v5 where experts are fused 3D tensors and
the router is ``MixtralTopKRouter`` (returns a 3-tuple).
"""
family = "mixtral"
def matches(self, model: nn.Module) -> bool:
return (
getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
)
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock"):
yield m
def prepare(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
if not self.uses_kernel_routing(moe_layer):
self._patch_router(moe_layer)
else:
LOG.info(
"AuxFreeMoE: kernel backend detected on %s; "
"skipping router patch (kernel routing handles bias)",
type(moe_layer).__name__,
)
def _patch_router(self, moe_layer: nn.Module) -> None:
"""Patch the TopKRouter to inject aux-free bias into expert selection."""
gate = getattr(moe_layer, "gate", None)
if gate is None:
LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
return
if getattr(gate, "_afb_patched", False):
return
# Capture reference to the MoE block for bias / counts access.
block_ref = moe_layer
def afb_router_forward(self, hidden_states: torch.Tensor):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight)
router_probs = F.softmax(router_logits.float(), dim=-1)
# Biased selection, unbiased weights
bias = block_ref._afb_bias
biased = router_probs + bias
_, router_indices = torch.topk(biased, self.top_k, dim=-1)
router_scores = torch.gather(router_probs, 1, router_indices)
# Renormalize (Mixtral always normalizes; Qwen checks config)
if getattr(self, "norm_topk_prob", True):
router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
# Accumulate counts for the bias-update callback
flat_idx = router_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=self.num_experts)
block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
return router_probs, router_scores, router_indices
gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
gate._afb_patched = True
moe_layer._afb_patched = True
class Qwen3Adapter(MixtralAdapter):
family = "qwen3_moe"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
return getattr(getattr(model, "config", object()), "model_type", "") in (
"qwen3_moe",
"qwen2_moe",
)
class Qwen35MoeAdapter(MixtralAdapter):
"""Adapter for Qwen 3.5 MoE models.
Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
is handled by the block forward (untouched by router-level patching).
"""
family = "qwen3_5_moe"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") in (
"qwen3_5_moe",
"qwen3_5_moe_text",
)
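For orientation, the shared-expert combination that router-level patching never touches typically looks like the following. This is a hedged sketch of the block-level pattern only; the module names are stand-ins and the authoritative version is the transformers block forward:

import torch
from torch import nn

hidden = torch.randn(4, 16)                # [tokens, hidden]
shared_expert = nn.Linear(16, 16)          # stand-in for the shared-expert MLP
shared_expert_gate = nn.Linear(16, 1)      # per-token scalar gate
routed = torch.zeros(4, 16)                # stand-in for the routed MoE output
out = routed + torch.sigmoid(shared_expert_gate(hidden)) * shared_expert(hidden)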
class BailingAdapter(BaseMoEAdapter):
@@ -207,11 +249,15 @@ class BailingAdapter(BaseMoEAdapter):
def get_num_experts(self, moe_layer: nn.Module) -> int:
if hasattr(moe_layer, "num_experts"):
return int(moe_layer.num_experts)
cfg = getattr(moe_layer, "config", None)
if cfg is None:
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
return int(cfg.num_experts)
def prepare(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_bailing_gate(moe_layer)
@@ -227,7 +273,7 @@ class BailingAdapter(BaseMoEAdapter):
flat = hidden_states.view(-1, hidden_states.shape[-1])
logits = F.linear(flat.float(), self.weight.float())
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
bias = getattr(moe_layer, "_afb_bias")
bias = moe_layer._afb_bias
biased_scores = scores_unbiased + bias
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
weights = torch.gather(scores_unbiased, 1, topk_idx)
@@ -238,12 +284,12 @@ class BailingAdapter(BaseMoEAdapter):
flat_topk = topk_idx.reshape(-1)
counts = torch.bincount(flat_topk, minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
return topk_idx, weights.to(hidden_states.dtype), logits
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
setattr(gate, "_afb_patched", True)
gate._afb_patched = True
class Llama4Adapter(BaseMoEAdapter):
@@ -257,7 +303,9 @@ class Llama4Adapter(BaseMoEAdapter):
if m.__class__.__name__ == "Llama4TextMoe":
yield m
def prepare(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_llama4_router(moe_layer)
@@ -270,9 +318,13 @@ class Llama4Adapter(BaseMoEAdapter):
return
def afb_router_forward(self, hidden_states: torch.Tensor):
flat = (
hidden_states
if hidden_states.dim() == 2
else hidden_states.view(-1, hidden_states.shape[-1])
)
router_logits = F.linear(flat, self.weight, self.bias)
bias = getattr(moe_layer, "_afb_bias")
bias = moe_layer._afb_bias
biased_logits = router_logits + bias
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
unbiased_top = torch.gather(router_logits, 1, router_indices)
@@ -281,15 +333,17 @@ class Llama4Adapter(BaseMoEAdapter):
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
return router_scores, router_logits
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
setattr(router, "_afb_patched", True)
router._afb_patched = True
def discover_and_prepare_layers(
model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
) -> list[LayerHandle]:
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
Returns a list of layer handles for later routing patching and updates.
@@ -313,7 +367,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
try:
top_k = adapter.get_top_k(layer)
nE = adapter.get_num_experts(layer)
except (AttributeError, TypeError, ValueError):
continue
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
@@ -321,5 +375,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
handles.append(handle)
idx += 1
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
LOG.info(
f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing"
)
return handles

View File

@@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection.
from __future__ import annotations
from typing import Any, Optional
import torch
import torch.distributed as dist
@@ -21,6 +21,7 @@ from .adapters import (
Llama4Adapter,
MixtralAdapter,
Qwen3Adapter,
Qwen35MoeAdapter,
discover_and_prepare_layers,
)
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
@@ -47,9 +48,11 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
# Iterate prepared MoE layers and apply the bias update rule.
self.shim.begin_step()
for layer in self.layer_modules:
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
if not hasattr(layer, "_afb_counts") or not hasattr(
layer, "_afb_layer_idx"
):
continue
counts = layer._afb_counts
if counts is None:
continue
counts = self.shim.all_reduce_counts(counts)
@@ -57,7 +60,7 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
if layer_idx is None:
counts.zero_()
continue
bias = layer._afb_bias
counts_for_update = counts.to(bias.device)
tokens_seen = int(counts_for_update.sum().item())
# local layer-state EMA and bias update
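The bias update rule itself lives in the shim and is not shown in this hunk. As a hedged sketch only, a DeepSeek-style sign update with rate u would look roughly like this; the rate and the exact error signal are assumptions, not read from this diff:

import torch

def update_bias(bias: torch.Tensor, counts: torch.Tensor, u: float = 1e-3) -> None:
    load = counts.float()
    err = load.mean() - load        # positive where an expert is underloaded
    bias.add_(u * torch.sign(err))  # nudge underloaded experts upward
    counts.zero_()                  # reset for the next accumulation window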
@@ -143,12 +146,16 @@ class AuxFreeMoEPlugin(BasePlugin):
return
# Be conservative — skip known native aux-free families
native_auxfree = getattr(
getattr(model, "config", object()), "model_type", ""
) in (
"deepseek_v3",
"glm4_moe",
)
if native_auxfree:
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
LOG.info(
"AuxFreeMoE: model reports native aux-free routing; skipping patching"
)
return
# Build aux-free state and shim
@@ -171,6 +178,7 @@ class AuxFreeMoEPlugin(BasePlugin):
adapters: list[BaseMoEAdapter] = [
MixtralAdapter(),
Qwen3Adapter(),
Qwen35MoeAdapter(),
BailingAdapter(),
Llama4Adapter(),
]
@@ -178,7 +186,7 @@ class AuxFreeMoEPlugin(BasePlugin):
# For initial state sizing, we conservatively assume the first discovered layer defines nE
n_layers = 0
n_experts = None
for _m in model.modules():
n_layers += 1 # upper bound — we will re-use bias slots sparsely
device = next(model.parameters(), torch.tensor(0)).device
if n_layers <= 0:
@@ -186,7 +194,9 @@ class AuxFreeMoEPlugin(BasePlugin):
if n_experts is None:
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
n_experts = 2
state = AuxFreeState(
num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
)
ep_size = getattr(cfg, "expert_parallel_size", None)
ep_group = None
if sync_group == "ep":
@@ -207,11 +217,15 @@ class AuxFreeMoEPlugin(BasePlugin):
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
if not dist.is_available() or not dist.is_initialized():
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
LOG.warning(
"AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
)
return None
ep_size = getattr(cfg, "expert_parallel_size", None)
if not ep_size or ep_size <= 1:
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
LOG.warning(
"AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
)
return None
world = dist.get_world_size()
if world % ep_size != 0:

View File

@@ -240,7 +240,16 @@ def _softmax_topk_route(
top_k = base_gate.top_k
num_experts = base_gate.num_experts
# Aux-free bias: biased selection, unbiased weights
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = routing_weights + afb_bias
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
routing_weights = routing_weights.gather(1, selected_experts)
_accumulate_afb_counts(moe_block, selected_experts)
else:
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
else:
scores_for_choice = router_probs
# Aux-free bias: stacks on top of e_score_correction_bias for selection
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = scores_for_choice + afb_bias
# Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1)
if n_group > 1:
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
)
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
"""Accumulate per-expert token counts for aux-free bias updates."""
afb_counts = getattr(moe_block, "_afb_counts", None)
if afb_counts is None:
return
flat_idx = topk_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
afb_counts.add_(counts.to(afb_counts.dtype))
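A toy check of what the bincount accumulation records (values are illustrative):

import torch

topk_indices = torch.tensor([[0, 2], [0, 1]])  # two tokens, top_k=2
counts = torch.bincount(topk_indices.reshape(-1), minlength=4)
# counts == tensor([2, 1, 1, 0]): expert 0 won two slots, expert 3 none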
# =============================================================================
# Shared expert helpers
# =============================================================================

View File

@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
as buffers. The routing functions transparently inject the bias into expert
*selection* (biased topk) while keeping mixture *weights* from unbiased
scores, then accumulate per-expert token counts for the post-step bias update.
"""
import torch
@@ -101,17 +107,25 @@ def softmax_topk_routing(
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Aux-free bias: biased selection, unbiased weights
afb_bias = getattr(moe_block, "_afb_bias", None)
scores_for_choice = router_probs
if afb_bias is not None:
scores_for_choice = router_probs + afb_bias
# Select top-k experts per token
top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
# When aux-free bias is active, gather unbiased weights and accumulate counts
if afb_bias is not None:
top_values = router_probs.gather(1, top_indices)
_accumulate_afb_counts(moe_block, top_indices)
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
# top_values = top_values.to(router_probs.dtype)
# Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout:
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
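The ascending token order falls out of row-major flattening of the [T, K] index grid; a toy illustration:

import torch

T, K = 3, 2
token_idx = torch.arange(T).unsqueeze(1).expand(T, K).reshape(-1)
# token_idx == tensor([0, 0, 1, 1, 2, 2]): sorted ascending, as SonicMoE requires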
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Aux-free bias: inject before group selection / topk
afb_bias = getattr(moe_block, "_afb_bias", None)
scores_for_choice = router_probs
if afb_bias is not None:
scores_for_choice = router_probs + afb_bias
# Group selection: pick top groups, mask the rest
if n_group > 1:
@@ -164,6 +182,10 @@ def softmax_group_topk_routing(
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
@@ -233,6 +255,11 @@ def sigmoid_topk_routing(
)
scores_for_choice = router_probs + e_score_correction_bias
# Aux-free bias: stacks on top of e_score_correction_bias for selection
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = scores_for_choice + afb_bias
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
if n_group > 1:
group_scores = (
@@ -256,6 +283,10 @@ def sigmoid_topk_routing(
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Optional renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
@@ -276,3 +307,19 @@ def sigmoid_topk_routing(
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
"""Accumulate per-expert token counts for the aux-free bias update.
Called when ``moe_block._afb_bias`` is present (registered by the
``aux_free_router`` plugin). The counts are later consumed by the
``MoeAuxFreeBiasUpdateCallback`` at each training step.
"""
afb_counts = getattr(moe_block, "_afb_counts", None)
if afb_counts is None:
return
num_experts = afb_counts.numel()
flat_idx = topk_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=num_experts)
afb_counts.add_(counts.to(afb_counts.dtype))

View File

@@ -2,14 +2,13 @@ import os
import sys
import tempfile
import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace
import torch
import torch.distributed as dist
import torch.nn as nn
from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
@@ -96,16 +95,21 @@ def _build_bailing_model():
def _build_llama4_model():
from transformers import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
# Build config without __post_init__ validation (works around a
# huggingface_hub strict-dataclass type mismatch for layer_types).
config = object.__new__(Llama4TextConfig)
config.__dict__.update(
hidden_size=16,
intermediate_size=32,
num_local_experts=4,
num_attention_heads=2,
num_key_value_heads=2,
num_experts_per_tok=2,
num_hidden_layers=2,
hidden_act="silu",
layer_types=None,
)
layer = Llama4TextMoe(config)
@@ -148,6 +152,38 @@ def _build_mixtral_model():
return DummyModel(layer), layer
def _build_qwen35_moe_model():
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeTextConfig,
)
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeSparseMoeBlock,
)
config = Qwen3_5MoeTextConfig(
hidden_size=16,
moe_intermediate_size=32,
shared_expert_intermediate_size=32,
num_experts=4,
num_experts_per_tok=2,
num_attention_heads=2,
num_key_value_heads=2,
num_hidden_layers=2,
)
layer = Qwen3_5MoeSparseMoeBlock(config)
class DummyModel(nn.Module):
def __init__(self, moe_layer):
super().__init__()
self.moe = moe_layer
self.config = SimpleNamespace(model_type="qwen3_5_moe")
def forward(self, hidden_states):
return self.moe(hidden_states)
return DummyModel(layer), layer
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
if args is None:
args = SimpleNamespace(logging_steps=1)
@@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
self.assertFalse(
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
)
def test_llama4_adapter_biases_router_selection(self):
model, layer = _build_llama4_model()
@@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
)
def test_bias_warmup_respected(self):
model, block = _build_bailing_model()
@@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase):
# Warmup steps should leave bias untouched.
_step()
self.assertTrue(
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
)
_step()
self.assertTrue(
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
)
# Third step exceeds warmup -> bias should update.
_step()
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
def test_mixtral_adapter_patches_router_not_forward(self):
"""Verify that aux-free patches the router (gate) only, and the
v5 block forward signature (single tensor return) is preserved."""
model, layer = _build_mixtral_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
# Gate should be patched, not the block forward
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
self.assertTrue(getattr(layer, "_afb_patched", False))
# v5 block forward returns a single tensor (not a tuple with logits)
hidden = torch.randn(2, 3, layer.config.hidden_size)
out = layer(hidden)
self.assertIsInstance(out, torch.Tensor)
self.assertEqual(out.shape, hidden.shape)
# Counts should have been accumulated
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg)
def test_mixtral_adapter_bias_affects_selection(self):
"""When bias is large for one expert, it should be selected more often."""
model, layer = _build_mixtral_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
# Set a large bias for expert 0 to force its selection
layer._afb_bias.zero_()
layer._afb_bias[0] = 10.0
hidden = torch.randn(2, 8, layer.config.hidden_size)
num_tokens = 2 * 8 # batch * seq
layer(hidden)
# With top_k=2, expert 0 should appear in every token's selection
# (once per token = num_tokens counts, not num_tokens * top_k)
counts = layer._afb_counts.clone()
self.assertEqual(
int(counts[0].item()),
num_tokens,
msg="Expert 0 should be selected for every token when heavily biased",
)
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
output includes shared expert contribution."""
model, layer = _build_qwen35_moe_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
# Gate should be patched
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
self.assertTrue(getattr(layer, "_afb_patched", False))
# Shared expert should be unmodified
self.assertTrue(hasattr(layer, "shared_expert"))
self.assertTrue(hasattr(layer, "shared_expert_gate"))
# Forward should return a single tensor (shared + routed)
hidden_size = layer.gate.hidden_dim
hidden = torch.randn(2, 3, hidden_size)
out = layer(hidden)
self.assertIsInstance(out, torch.Tensor)
self.assertEqual(out.shape, hidden.shape)
# Counts should have been accumulated
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
def test_qwen35_moe_adapter_bias_updates(self):
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
model, layer = _build_qwen35_moe_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
hidden_size = layer.gate.hidden_dim
hidden = torch.randn(2, 4, hidden_size)
layer(hidden)
# Bias should start at zero
self.assertTrue(
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
)
_run_callback(plugin, cfg)
# After callback: counts reset, EMA updated, bias updated
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
)
def test_qwen35_moe_adapter_model_type_matching(self):
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
adapter = Qwen35MoeAdapter()
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
model_text = SimpleNamespace(
config=SimpleNamespace(model_type="qwen3_5_moe_text")
)
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
self.assertTrue(adapter.matches(model_moe))
self.assertTrue(adapter.matches(model_text))
self.assertFalse(adapter.matches(model_other))
def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()
@@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close()
init_method = f"file://{tmp_init.name}"
dist.init_process_group(
backend="gloo", init_method=init_method, world_size=1, rank=0
)
try:
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
@@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
@@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase):
self.assertIn("moe_afb/l0_load_max", telemetry)
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
def test_get_num_experts_v5_attribute_paths(self):
"""Verify get_num_experts works with v5 attribute layout where
num_experts is on gate/experts sub-modules, not the block."""
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
adapter = MixtralAdapter()
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
block = SimpleNamespace(
gate=SimpleNamespace(num_experts=8),
experts=SimpleNamespace(num_experts=8),
)
self.assertEqual(adapter.get_num_experts(block), 8)
# Also works when num_experts is directly on block
block2 = SimpleNamespace(num_experts=4)
self.assertEqual(adapter.get_num_experts(block2), 4)
class TestAuxFreeKernelComposition(unittest.TestCase):
"""Tests that aux-free bias composes correctly with kernel routing."""
def test_sonicmoe_softmax_routing_with_afb_bias(self):
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
num_experts = 4
top_k = 2
hidden_dim = 16
T = 6
# Build a mock MoE block with gate attributes
gate = nn.Linear(hidden_dim, num_experts, bias=False)
gate.top_k = top_k
gate.num_experts = num_experts
gate.norm_topk_prob = True
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(T, hidden_dim)
# Baseline: no bias
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
hidden, moe_block
)
self.assertEqual(scores_base.shape[0], T * top_k)
# Now register aux-free buffers and set heavy bias on expert 0
moe_block._afb_bias = torch.zeros(num_experts)
moe_block._afb_bias[0] = 100.0
moe_block._afb_counts = torch.zeros(num_experts)
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
hidden, moe_block
)
# Expert 0 should be selected for every token
self.assertTrue(
(exp_biased == 0).any(),
"Expert 0 should appear in selections when heavily biased",
)
# Counts should have been accumulated
self.assertGreater(moe_block._afb_counts[0].item(), 0)
# Total counts should equal T * top_k
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
def test_sonicmoe_routing_without_bias_unchanged(self):
"""Without _afb_bias, routing should produce identical results."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
num_experts = 4
top_k = 2
hidden_dim = 16
gate = nn.Linear(hidden_dim, num_experts, bias=False)
gate.top_k = top_k
gate.num_experts = num_experts
gate.norm_topk_prob = True
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(4, hidden_dim)
# Without _afb_bias attribute
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
# With _afb_bias = zeros (should be equivalent)
moe_block._afb_bias = torch.zeros(num_experts)
moe_block._afb_counts = torch.zeros(num_experts)
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
torch.testing.assert_close(scores1, scores2)
torch.testing.assert_close(exp1, exp2)
@unittest.skipUnless(
importlib_util.find_spec("triton") is not None,
"triton not installed (required by scattermoe)",
)
def test_scattermoe_softmax_routing_with_afb_bias(self):
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
num_experts = 4
top_k = 2
hidden_dim = 16
T = 6
gate_weight = torch.randn(num_experts, hidden_dim)
base_gate = SimpleNamespace(
top_k=top_k,
num_experts=num_experts,
norm_topk_prob=True,
weight=gate_weight,
)
moe_block = SimpleNamespace()
hidden = torch.randn(T, hidden_dim)
# Baseline without bias
w_base, e_base, _, _ = _softmax_topk_route(
moe_block, base_gate, hidden, gate_weight, None
)
# With heavy bias on expert 0
moe_block._afb_bias = torch.zeros(num_experts)
moe_block._afb_bias[0] = 100.0
moe_block._afb_counts = torch.zeros(num_experts)
w_biased, e_biased, _, _ = _softmax_topk_route(
moe_block, base_gate, hidden, gate_weight, None
)
# Expert 0 should appear in all selections
self.assertTrue((e_biased == 0).any())
# Counts accumulated
self.assertGreater(moe_block._afb_counts[0].item(), 0)
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
def test_kernel_routing_skips_router_patch(self):
"""When a kernel backend has patched the block class, the adapter
should skip patching the router (buffers are still registered)."""
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
adapter = MixtralAdapter()
# Create a mock layer whose class has _original_forward (SonicMoE marker)
class PatchedBlock(nn.Module):
_original_forward = True # SonicMoE marker
def __init__(self):
super().__init__()
self.gate = nn.Linear(16, 4, bias=False)
self.gate.top_k = 2
self.gate.num_experts = 4
self.gate.hidden_dim = 16
self.experts = nn.Linear(16, 16) # placeholder
layer = PatchedBlock()
self.assertTrue(adapter.uses_kernel_routing(layer))
# Gate should NOT be patched (kernel handles routing)
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
def test_adapter_buffers_registered_even_with_kernel(self):
"""Even when kernel routing is active, aux-free buffers must be
registered on the MoE block so the kernel routing can find them."""
from axolotl.integrations.aux_free_router.adapters import (
LayerHandle,
MixtralAdapter,
)
from axolotl.integrations.aux_free_router.core import (
AuxFreeConfig,
AuxFreeShim,
AuxFreeState,
)
class PatchedBlock(nn.Module):
_original_forward = True
def __init__(self):
super().__init__()
self.gate = nn.Linear(16, 4, bias=False)
self.gate.top_k = 2
self.gate.num_experts = 4
self.gate.hidden_dim = 16
self.experts = nn.Linear(16, 16)
layer = PatchedBlock()
adapter = MixtralAdapter()
cfg = AuxFreeConfig()
state = AuxFreeState(
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
)
shim = AuxFreeShim(state=state)
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
adapter.prepare(layer, handle, shim)
# Buffers should be registered for kernel routing to use
self.assertTrue(hasattr(layer, "_afb_bias"))
self.assertTrue(hasattr(layer, "_afb_counts"))
self.assertTrue(hasattr(layer, "_afb_ema"))
# But gate should NOT be patched
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
if __name__ == "__main__":
unittest.main()