Update aux-loss-free MoE routing for transformers v5 expert parameters and compose with MoE kernels

This commit is contained in:
Wing Lian
2026-03-22 11:52:34 -04:00
parent 4009a2ba5f
commit 5acb1b0ade
5 changed files with 650 additions and 162 deletions

View File

@@ -1,11 +1,19 @@
"""Architecture-specific adapters for aux-loss-free MoE routing.
Each adapter discovers MoE layers for a model family and patches only the
router/gate to inject per-expert bias into expert selection while keeping
mixture weights from unbiased logits. Expert dispatch is left untouched so
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
"""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, Optional from typing import Iterable, Optional
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -23,9 +31,10 @@ class LayerHandle:
class BaseMoEAdapter: class BaseMoEAdapter:
"""Base adapter that discovers MoE layers and wraps their forward. """Base adapter that discovers MoE layers and patches their routing.
Concrete adapters should implement discovery and per-layer attribute extraction. Concrete adapters implement discovery, attribute extraction, and
architecture-specific router patching.
""" """
family: str = "generic" family: str = "generic"
@@ -33,158 +42,191 @@ class BaseMoEAdapter:
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
return False return False
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover def find_moe_layers(
self, model: nn.Module
) -> Iterable[nn.Module]: # pragma: no cover
return [] return []
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover def get_top_k(self, moe_layer: nn.Module) -> int:
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2))) """Resolve top_k from the MoE layer, checking common attribute paths."""
for attr_path in [
("top_k",),
("num_experts_per_tok",),
("gate", "top_k"),
("router", "top_k"),
]:
obj: object = moe_layer
for attr in attr_path:
obj = getattr(obj, attr, None)
if obj is None:
break
if isinstance(obj, int):
return obj
return 2
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover def get_num_experts(self, moe_layer: nn.Module) -> int:
return int(getattr(moe_layer, "num_experts")) """Resolve num_experts from the MoE layer, checking common attribute paths."""
for attr_path in [
("num_experts",),
("num_local_experts",),
("gate", "num_experts"),
("router", "num_experts"),
("experts", "num_experts"),
]:
obj: object = moe_layer
for attr in attr_path:
obj = getattr(obj, attr, None)
if obj is None:
break
if isinstance(obj, int):
return obj
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
def disable_aux_loss(self, model_or_layer: nn.Module) -> None: def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
# Best-effort: zero router aux loss coef if present # Best-effort: zero router aux loss coef if present
if hasattr(model_or_layer, "router_aux_loss_coef"): if hasattr(model_or_layer, "router_aux_loss_coef"):
try: try:
setattr(model_or_layer, "router_aux_loss_coef", 0.0) model_or_layer.router_aux_loss_coef = 0.0
except Exception: # pragma: no cover - non-critical except Exception: # pragma: no cover - non-critical
pass pass
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: def _register_aux_buffers(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
device = next(moe_layer.parameters(), torch.tensor(0)).device device = next(moe_layer.parameters(), torch.tensor(0)).device
if not hasattr(moe_layer, "_afb_bias"): if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device)) moe_layer.register_buffer(
"_afb_bias", torch.zeros(handle.num_experts, device=device)
)
if not hasattr(moe_layer, "_afb_counts"): if not hasattr(moe_layer, "_afb_counts"):
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device)) moe_layer.register_buffer(
"_afb_counts", torch.zeros(handle.num_experts, device=device)
)
if not hasattr(moe_layer, "_afb_ema"): if not hasattr(moe_layer, "_afb_ema"):
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device)) moe_layer.register_buffer(
"_afb_ema", torch.zeros(handle.num_experts, device=device)
)
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined] moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined] moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
shim.register_layer_buffers(handle.layer_idx, moe_layer) shim.register_layer_buffers(handle.layer_idx, moe_layer)
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: def prepare(
"""Attach per-layer buffers and mark as aux-free enabled.""" self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
"""Attach per-layer buffers. Subclasses override to also patch routing."""
self._register_aux_buffers(moe_layer, handle, shim) self._register_aux_buffers(moe_layer, handle, shim)
self._patch_forward_with_aux_free(moe_layer)
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None: def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
"""Replace the layer's forward with an aux-free gating version. """Return True when a kernel backend (SonicMoE / ScatterMoE) has
already replaced the block forward, meaning the routing is handled
Assumes the layer exposes attributes: inside the kernel forward and we should NOT patch the router."""
- gate: linear router projecting hidden to num_experts cls = type(moe_layer)
- num_experts: int # SonicMoE stores the original forward when it patches a class.
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H) if hasattr(cls, "_original_forward"):
""" return True
if getattr(moe_layer, "_afb_patched", False): # ScatterMoE replaces via kernels library; check for the marker.
return if hasattr(cls, "_kernel_forward"):
return True
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"): return False
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
return
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
# hidden_states: (B, T, H)
bsz, seqlen, hdim = hidden_states.shape
hs = hidden_states.view(-1, hdim)
logits = self.gate(hs)
# selection uses biased logits; weights from unbiased logits
bias = getattr(self, "_afb_bias")
top_k = int(getattr(self, "_afb_top_k", 2))
biased = logits + bias # broadcast over tokens
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
chosen_logits = torch.gather(logits, -1, topk_idx)
weights = torch.softmax(chosen_logits.float(), dim=-1)
weights = weights.to(hs.dtype)
# accumulate counts for bias update callback
flat_idx = topk_idx.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
# dispatch tokens to experts
hs_rep = hs.repeat_interleave(top_k, dim=0)
y = torch.empty_like(hs_rep)
for eid in range(int(self.num_experts)):
mask = flat_idx == eid
if mask.any():
y[mask] = self.experts[eid](hs_rep[mask])
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
out = y.view(bsz, seqlen, hdim)
return (out, logits)
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
setattr(moe_layer, "_afb_patched", True)
class MixtralAdapter(BaseMoEAdapter): class MixtralAdapter(BaseMoEAdapter):
"""Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
routing so that biased logits drive expert *selection* while unbiased
softmax scores drive mixture *weights*.
Works with transformers v5 where experts are fused 3D tensors and
the router is ``MixtralTopKRouter`` (returns a 3-tuple).
"""
family = "mixtral" family = "mixtral"
def matches(self, model: nn.Module) -> bool: def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral" return (
getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: )
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_mixtral_forward(moe_layer, shim)
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules(): for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock"): if m.__class__.__name__.endswith("SparseMoeBlock"):
yield m yield m
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None: def prepare(
if getattr(moe_layer, "_afb_patched", False): self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
return ) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
shim_ref = shim if not self.uses_kernel_routing(moe_layer):
self._patch_router(moe_layer)
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef] else:
batch_size, sequence_length, hidden_dim = hidden_states.shape LOG.info(
if self.training and getattr(self, "jitter_noise", 0) > 0: "AuxFreeMoE: kernel backend detected on %s; "
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_( "skipping router patch (kernel routing handles bias)",
1.0 - self.jitter_noise, 1.0 + self.jitter_noise type(moe_layer).__name__,
)
flat_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(flat_states)
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
top_k = int(getattr(self, "_afb_top_k", self.top_k))
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
routing_weights = routing_weights.to(flat_states.dtype)
flat_idx = selected_experts.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=flat_states.dtype,
device=flat_states.device,
) )
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) def _patch_router(self, moe_layer: nn.Module) -> None:
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() """Patch the TopKRouter to inject aux-free bias into expert selection."""
for expert_idx_tensor in expert_hit: gate = getattr(moe_layer, "gate", None)
expert_idx = int(expert_idx_tensor.squeeze().item()) if gate is None:
expert_layer = self.experts[expert_idx] LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
mask = expert_mask[expert_idx].squeeze(0) return
idx, top_x = torch.where(mask) if getattr(gate, "_afb_patched", False):
current_state = flat_states[None, top_x].reshape(-1, hidden_dim) return
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) # Capture reference to the MoE block for bias / counts access.
return final_hidden_states, router_logits block_ref = moe_layer
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined] def afb_router_forward(self, hidden_states: torch.Tensor):
setattr(moe_layer, "_afb_patched", True) hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight)
router_probs = F.softmax(router_logits.float(), dim=-1)
# Biased selection, unbiased weights
bias = block_ref._afb_bias
biased = router_probs + bias
_, router_indices = torch.topk(biased, self.top_k, dim=-1)
router_scores = torch.gather(router_probs, 1, router_indices)
# Renormalize (Mixtral always normalizes; Qwen checks config)
if getattr(self, "norm_topk_prob", True):
router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
# Accumulate counts for the bias-update callback
flat_idx = router_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=self.num_experts)
block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
return router_probs, router_scores, router_indices
gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
gate._afb_patched = True
moe_layer._afb_patched = True
class Qwen3Adapter(MixtralAdapter): class Qwen3Adapter(MixtralAdapter):
family = "qwen3_moe" family = "qwen3_moe"
def matches(self, model: nn.Module) -> bool: def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe") return getattr(getattr(model, "config", object()), "model_type", "") in (
"qwen3_moe",
"qwen2_moe",
)
class Qwen35MoeAdapter(MixtralAdapter):
    """Aux-free routing adapter for the Qwen 3.5 MoE family.

    Inherits the Mixtral softmax→topk router patch unchanged; only model
    matching differs. The shared expert remains in the block forward, which
    router-level patching does not modify.
    """

    family = "qwen3_5_moe"

    def matches(self, model: nn.Module) -> bool:
        """Return True when ``model.config.model_type`` is a Qwen 3.5 MoE type."""
        # A model without a config (or without model_type) resolves to "" and
        # therefore never matches.
        config = getattr(model, "config", object())
        model_type = getattr(config, "model_type", "")
        return model_type in ("qwen3_5_moe", "qwen3_5_moe_text")
class BailingAdapter(BaseMoEAdapter): class BailingAdapter(BaseMoEAdapter):
@@ -207,11 +249,15 @@ class BailingAdapter(BaseMoEAdapter):
def get_num_experts(self, moe_layer: nn.Module) -> int: def get_num_experts(self, moe_layer: nn.Module) -> int:
if hasattr(moe_layer, "num_experts"): if hasattr(moe_layer, "num_experts"):
return int(getattr(moe_layer, "num_experts")) return int(moe_layer.num_experts)
cfg = getattr(moe_layer, "config", None) cfg = getattr(moe_layer, "config", None)
return int(getattr(cfg, "num_experts")) if cfg is None:
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
return int(cfg.num_experts)
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: def prepare(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
self._register_aux_buffers(moe_layer, handle, shim) self._register_aux_buffers(moe_layer, handle, shim)
self._patch_bailing_gate(moe_layer) self._patch_bailing_gate(moe_layer)
@@ -227,7 +273,7 @@ class BailingAdapter(BaseMoEAdapter):
flat = hidden_states.view(-1, hidden_states.shape[-1]) flat = hidden_states.view(-1, hidden_states.shape[-1])
logits = F.linear(flat.float(), self.weight.float()) logits = F.linear(flat.float(), self.weight.float())
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype) scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
bias = getattr(moe_layer, "_afb_bias") bias = moe_layer._afb_bias
biased_scores = scores_unbiased + bias biased_scores = scores_unbiased + bias
topk_vals, topk_idx = self.group_limited_topk(biased_scores) topk_vals, topk_idx = self.group_limited_topk(biased_scores)
weights = torch.gather(scores_unbiased, 1, topk_idx) weights = torch.gather(scores_unbiased, 1, topk_idx)
@@ -238,12 +284,12 @@ class BailingAdapter(BaseMoEAdapter):
flat_topk = topk_idx.reshape(-1) flat_topk = topk_idx.reshape(-1)
counts = torch.bincount(flat_topk, minlength=bias.numel()) counts = torch.bincount(flat_topk, minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype)) moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
return topk_idx, weights.to(hidden_states.dtype), logits return topk_idx, weights.to(hidden_states.dtype), logits
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined] gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
setattr(gate, "_afb_patched", True) gate._afb_patched = True
class Llama4Adapter(BaseMoEAdapter): class Llama4Adapter(BaseMoEAdapter):
@@ -257,7 +303,9 @@ class Llama4Adapter(BaseMoEAdapter):
if m.__class__.__name__ == "Llama4TextMoe": if m.__class__.__name__ == "Llama4TextMoe":
yield m yield m
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: def prepare(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None:
self._register_aux_buffers(moe_layer, handle, shim) self._register_aux_buffers(moe_layer, handle, shim)
self._patch_llama4_router(moe_layer) self._patch_llama4_router(moe_layer)
@@ -270,9 +318,13 @@ class Llama4Adapter(BaseMoEAdapter):
return return
def afb_router_forward(self, hidden_states: torch.Tensor): def afb_router_forward(self, hidden_states: torch.Tensor):
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1]) flat = (
hidden_states
if hidden_states.dim() == 2
else hidden_states.view(-1, hidden_states.shape[-1])
)
router_logits = F.linear(flat, self.weight, self.bias) router_logits = F.linear(flat, self.weight, self.bias)
bias = getattr(moe_layer, "_afb_bias") bias = moe_layer._afb_bias
biased_logits = router_logits + bias biased_logits = router_logits + bias
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1) _, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
unbiased_top = torch.gather(router_logits, 1, router_indices) unbiased_top = torch.gather(router_logits, 1, router_indices)
@@ -281,15 +333,17 @@ class Llama4Adapter(BaseMoEAdapter):
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype) router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel()) counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype)) moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
return router_scores, router_logits return router_scores, router_logits
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined] router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
setattr(router, "_afb_patched", True) router._afb_patched = True
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]: def discover_and_prepare_layers(
model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
) -> list[LayerHandle]:
"""Discover MoE layers using the first matching adapter and attach per-layer buffers. """Discover MoE layers using the first matching adapter and attach per-layer buffers.
Returns a list of layer handles for later routing patching and updates. Returns a list of layer handles for later routing patching and updates.
@@ -313,7 +367,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
try: try:
top_k = adapter.get_top_k(layer) top_k = adapter.get_top_k(layer)
nE = adapter.get_num_experts(layer) nE = adapter.get_num_experts(layer)
except Exception: except (AttributeError, TypeError, ValueError):
continue continue
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k) handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
@@ -321,5 +375,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
handles.append(handle) handles.append(handle)
idx += 1 idx += 1
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing") LOG.info(
f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing"
)
return handles return handles

View File

@@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection.
from __future__ import annotations from __future__ import annotations
from typing import Optional, Any from typing import Any, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -21,6 +21,7 @@ from .adapters import (
Llama4Adapter, Llama4Adapter,
MixtralAdapter, MixtralAdapter,
Qwen3Adapter, Qwen3Adapter,
Qwen35MoeAdapter,
discover_and_prepare_layers, discover_and_prepare_layers,
) )
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
@@ -47,9 +48,11 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
# Iterate prepared MoE layers and apply the bias update rule. # Iterate prepared MoE layers and apply the bias update rule.
self.shim.begin_step() self.shim.begin_step()
for layer in self.layer_modules: for layer in self.layer_modules:
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"): if not hasattr(layer, "_afb_counts") or not hasattr(
layer, "_afb_layer_idx"
):
continue continue
counts = getattr(layer, "_afb_counts") counts = layer._afb_counts
if counts is None: if counts is None:
continue continue
counts = self.shim.all_reduce_counts(counts) counts = self.shim.all_reduce_counts(counts)
@@ -57,7 +60,7 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
if layer_idx is None: if layer_idx is None:
counts.zero_() counts.zero_()
continue continue
bias = getattr(layer, "_afb_bias") bias = layer._afb_bias
counts_for_update = counts.to(bias.device) counts_for_update = counts.to(bias.device)
tokens_seen = int(counts_for_update.sum().item()) tokens_seen = int(counts_for_update.sum().item())
# local layer-state EMA and bias update # local layer-state EMA and bias update
@@ -143,12 +146,16 @@ class AuxFreeMoEPlugin(BasePlugin):
return return
# Be conservative — skip known native aux-free families # Be conservative — skip known native aux-free families
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in ( native_auxfree = getattr(
getattr(model, "config", object()), "model_type", ""
) in (
"deepseek_v3", "deepseek_v3",
"glm4_moe", "glm4_moe",
) )
if native_auxfree: if native_auxfree:
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching") LOG.info(
"AuxFreeMoE: model reports native aux-free routing; skipping patching"
)
return return
# Build aux-free state and shim # Build aux-free state and shim
@@ -171,6 +178,7 @@ class AuxFreeMoEPlugin(BasePlugin):
adapters: list[BaseMoEAdapter] = [ adapters: list[BaseMoEAdapter] = [
MixtralAdapter(), MixtralAdapter(),
Qwen3Adapter(), Qwen3Adapter(),
Qwen35MoeAdapter(),
BailingAdapter(), BailingAdapter(),
Llama4Adapter(), Llama4Adapter(),
] ]
@@ -178,7 +186,7 @@ class AuxFreeMoEPlugin(BasePlugin):
# For initial state sizing, we conservatively assume the first discovered layer defines nE # For initial state sizing, we conservatively assume the first discovered layer defines nE
n_layers = 0 n_layers = 0
n_experts = None n_experts = None
for m in model.modules(): for _m in model.modules():
n_layers += 1 # upper bound — we will re-use bias slots sparsely n_layers += 1 # upper bound — we will re-use bias slots sparsely
device = next(model.parameters(), torch.tensor(0)).device device = next(model.parameters(), torch.tensor(0)).device
if n_layers <= 0: if n_layers <= 0:
@@ -186,7 +194,9 @@ class AuxFreeMoEPlugin(BasePlugin):
if n_experts is None: if n_experts is None:
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
n_experts = 2 n_experts = 2
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg) state = AuxFreeState(
num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
)
ep_size = getattr(cfg, "expert_parallel_size", None) ep_size = getattr(cfg, "expert_parallel_size", None)
ep_group = None ep_group = None
if sync_group == "ep": if sync_group == "ep":
@@ -207,11 +217,15 @@ class AuxFreeMoEPlugin(BasePlugin):
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]: def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
if not dist.is_available() or not dist.is_initialized(): if not dist.is_available() or not dist.is_initialized():
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world") LOG.warning(
"AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
)
return None return None
ep_size = getattr(cfg, "expert_parallel_size", None) ep_size = getattr(cfg, "expert_parallel_size", None)
if not ep_size or ep_size <= 1: if not ep_size or ep_size <= 1:
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world") LOG.warning(
"AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
)
return None return None
world = dist.get_world_size() world = dist.get_world_size()
if world % ep_size != 0: if world % ep_size != 0:

View File

@@ -240,7 +240,16 @@ def _softmax_topk_route(
top_k = base_gate.top_k top_k = base_gate.top_k
num_experts = base_gate.num_experts num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# Aux-free bias: biased selection, unbiased weights
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = routing_weights + afb_bias
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
routing_weights = routing_weights.gather(1, selected_experts)
_accumulate_afb_counts(moe_block, selected_experts)
else:
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True): if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
else: else:
scores_for_choice = router_probs scores_for_choice = router_probs
# Aux-free bias: stacks on top of e_score_correction_bias for selection
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = scores_for_choice + afb_bias
# Group-based selection: pick top groups, mask the rest # Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1) n_group = getattr(moe_block, "n_group", 1)
if n_group > 1: if n_group > 1:
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
# Gather weights from original sigmoid scores (not bias-corrected) # Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices) topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Optional renormalization + scaling # Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True): if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
) )
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
    """Add this batch's per-expert selection counts to ``moe_block._afb_counts``.

    Args:
        moe_block: Object that may carry an ``_afb_counts`` tensor. When the
            attribute is missing or ``None`` the call is a no-op.
        topk_indices: Integer tensor of selected expert ids, any shape; it is
            flattened before counting.

    The buffer is updated in place; nothing is returned.
    """
    afb_counts = getattr(moe_block, "_afb_counts", None)
    if afb_counts is None:
        return
    # minlength keeps the histogram as long as the buffer even when the
    # highest-numbered experts received no tokens in this batch.
    counts = torch.bincount(topk_indices.reshape(-1), minlength=afb_counts.numel())
    # Match the buffer's device as well as its dtype so the in-place add does
    # not fail when routing ran on a different device than the buffer.
    afb_counts.add_(counts.to(device=afb_counts.device, dtype=afb_counts.dtype))
# ============================================================================= # =============================================================================
# Shared expert helpers # Shared expert helpers
# ============================================================================= # =============================================================================

View File

@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
Each model type maps to a (routing_fn, activation_type, router_attr) triple. Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
as buffers. The routing functions transparently inject the bias into expert
*selection* (biased topk) while keeping mixture *weights* from unbiased
scores, then accumulate per-expert token counts for the post-step bias update.
""" """
import torch import torch
@@ -101,17 +107,25 @@ def softmax_topk_routing(
router_logits = F.linear(hidden_states, gate.weight) # [T, E] router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Aux-free bias: biased selection, unbiased weights
afb_bias = getattr(moe_block, "_afb_bias", None)
scores_for_choice = router_probs
if afb_bias is not None:
scores_for_choice = router_probs + afb_bias
# Select top-k experts per token # Select top-k experts per token
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
# When aux-free bias is active, gather unbiased weights and accumulate counts
if afb_bias is not None:
top_values = router_probs.gather(1, top_indices)
_accumulate_afb_counts(moe_block, top_indices)
# Renormalize if configured (default True for models without the attribute, # Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize) # e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True): if getattr(gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True) top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
# top_values = top_values.to(router_probs.dtype)
# Flatten for moe_general_routing_inputs. # Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout: # Token indices are naturally sorted ascending from the [T, K] layout:
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE. # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
router_logits = F.linear(hidden_states, gate.weight) # [T, E] router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Aux-free bias: inject before group selection / topk
afb_bias = getattr(moe_block, "_afb_bias", None)
scores_for_choice = router_probs scores_for_choice = router_probs
if afb_bias is not None:
scores_for_choice = router_probs + afb_bias
# Group selection: pick top groups, mask the rest # Group selection: pick top groups, mask the rest
if n_group > 1: if n_group > 1:
@@ -164,6 +182,10 @@ def softmax_group_topk_routing(
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices) topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Renormalization + scaling # Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob: if norm_topk_prob:
@@ -233,6 +255,11 @@ def sigmoid_topk_routing(
) )
scores_for_choice = router_probs + e_score_correction_bias scores_for_choice = router_probs + e_score_correction_bias
# Aux-free bias: stacks on top of e_score_correction_bias for selection
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = scores_for_choice + afb_bias
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1) # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
if n_group > 1: if n_group > 1:
group_scores = ( group_scores = (
@@ -256,6 +283,10 @@ def sigmoid_topk_routing(
# Gather weights from original sigmoid scores (not bias-corrected) # Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices) topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Optional renormalization + scaling # Optional renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob: if norm_topk_prob:
@@ -276,3 +307,19 @@ def sigmoid_topk_routing(
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
"""Accumulate per-expert token counts for the aux-free bias update.
Called when ``moe_block._afb_bias`` is present (registered by the
``aux_free_router`` plugin). The counts are later consumed by the
``MoeAuxFreeBiasUpdateCallback`` at each training step.
"""
afb_counts = getattr(moe_block, "_afb_counts", None)
if afb_counts is None:
return
num_experts = afb_counts.numel()
flat_idx = topk_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=num_experts)
afb_counts.add_(counts.to(afb_counts.dtype))

View File

@@ -2,14 +2,13 @@ import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from importlib import util as importlib_util
from pathlib import Path
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
@@ -96,16 +95,21 @@ def _build_bailing_model():
def _build_llama4_model(): def _build_llama4_model():
from transformers import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
config = Llama4TextConfig( # Build config without __post_init__ validation (works around a
# huggingface_hub strict-dataclass type mismatch for layer_types).
config = object.__new__(__import__("transformers").Llama4TextConfig)
config.__dict__.update(
hidden_size=16, hidden_size=16,
intermediate_size=32, intermediate_size=32,
num_local_experts=4, num_local_experts=4,
num_attention_heads=2, num_attention_heads=2,
num_key_value_heads=2, num_key_value_heads=2,
num_experts_per_tok=2, num_experts_per_tok=2,
num_hidden_layers=2,
hidden_act="silu",
layer_types=None,
) )
layer = Llama4TextMoe(config) layer = Llama4TextMoe(config)
@@ -148,6 +152,38 @@ def _build_mixtral_model():
return DummyModel(layer), layer return DummyModel(layer), layer
def _build_qwen35_moe_model():
    """Build a tiny Qwen 3.5 sparse-MoE block and a thin model around it.

    Returns:
        A ``(model, layer)`` pair. ``layer`` is a ``Qwen3_5MoeSparseMoeBlock``
        built from a small text config; ``model`` is a minimal ``nn.Module``
        that stores it as ``model.moe``, forwards straight to it, and exposes
        ``config.model_type == "qwen3_5_moe"``.
    """
    from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
        Qwen3_5MoeTextConfig,
    )
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
        Qwen3_5MoeSparseMoeBlock,
    )

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="qwen3_5_moe")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    # Deliberately small dimensions so the block is cheap to construct.
    tiny_dims = {
        "hidden_size": 16,
        "moe_intermediate_size": 32,
        "shared_expert_intermediate_size": 32,
        "num_experts": 4,
        "num_experts_per_tok": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "num_hidden_layers": 2,
    }
    moe_block = Qwen3_5MoeSparseMoeBlock(Qwen3_5MoeTextConfig(**tiny_dims))
    return DummyModel(moe_block), moe_block
def _run_callback(plugin, cfg, *, args=None, state=None, control=None): def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
if args is None: if args is None:
args = SimpleNamespace(logging_steps=1) args = SimpleNamespace(logging_steps=1)
@@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
_run_callback(plugin, cfg) _run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(block._afb_counts), 0) self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))) self.assertFalse(
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
)
def test_llama4_adapter_biases_router_selection(self): def test_llama4_adapter_biases_router_selection(self):
model, layer = _build_llama4_model() model, layer = _build_llama4_model()
@@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
_run_callback(plugin, cfg) _run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0) self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))) self.assertFalse(
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
)
def test_bias_warmup_respected(self): def test_bias_warmup_respected(self):
model, block = _build_bailing_model() model, block = _build_bailing_model()
@@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase):
# Warmup steps should leave bias untouched. # Warmup steps should leave bias untouched.
_step() _step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))) self.assertTrue(
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
)
_step() _step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))) self.assertTrue(
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
)
# Third step exceeds warmup -> bias should update. # Third step exceeds warmup -> bias should update.
_step() _step()
self.assertGreater(torch.count_nonzero(block._afb_bias), 0) self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
def test_mixtral_adapter_respects_native_forward(self): def test_mixtral_adapter_patches_router_not_forward(self):
"""Verify that aux-free patches the router (gate) only, and the
v5 block forward signature (single tensor return) is preserved."""
model, layer = _build_mixtral_model() model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
hidden_dim = layer.config.hidden_size
hidden = torch.randn(2, 3, hidden_dim)
baseline_out, baseline_logits = layer(hidden.clone())
cfg = _cfg() cfg = _cfg()
plugin = AuxFreeMoEPlugin() plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model) plugin.post_model_build(cfg, model)
patched_out, patched_logits = layer(hidden.clone()) # Gate should be patched, not the block forward
self.assertTrue(torch.allclose(baseline_out, patched_out)) self.assertTrue(getattr(layer.gate, "_afb_patched", False))
self.assertTrue(torch.allclose(baseline_logits, patched_logits)) self.assertTrue(getattr(layer, "_afb_patched", False))
# v5 block forward returns a single tensor (not a tuple with logits)
hidden = torch.randn(2, 3, layer.config.hidden_size)
out = layer(hidden)
self.assertIsInstance(out, torch.Tensor)
self.assertEqual(out.shape, hidden.shape)
# Counts should have been accumulated
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0) self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg) _run_callback(plugin, cfg)
def test_mixtral_adapter_bias_affects_selection(self):
    """A dominant bias on one expert pulls it into every token's top-k."""
    model, layer = _build_mixtral_model()
    config = _cfg()
    aux_free_plugin = AuxFreeMoEPlugin()
    aux_free_plugin.post_model_build(config, model)

    # Make expert 0 overwhelmingly attractive to the router.
    layer._afb_bias.zero_()
    layer._afb_bias[0] = 10.0

    batch_size, seq_len = 2, 8
    layer(torch.randn(batch_size, seq_len, layer.config.hidden_size))

    # A token can select expert 0 at most once, so with top_k=2 the expected
    # tally is one per token rather than one per routing slot.
    expert0_hits = int(layer._afb_counts.clone()[0].item())
    self.assertEqual(
        expert0_hits,
        batch_size * seq_len,
        msg="Expert 0 should be selected for every token when heavily biased",
    )
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
    """Qwen 3.5 MoE: gate and block carry the patch marker, the shared-expert
    attributes are still present, and forward yields one same-shape tensor."""
    model, layer = _build_qwen35_moe_model()
    config = _cfg()
    AuxFreeMoEPlugin().post_model_build(config, model)

    # The patch marker is expected on both the gate and the block itself.
    for patched_obj in (layer.gate, layer):
        self.assertTrue(getattr(patched_obj, "_afb_patched", False))

    # Shared-expert attributes must still exist after patching.
    for attr_name in ("shared_expert", "shared_expert_gate"):
        self.assertTrue(hasattr(layer, attr_name))

    # The block forward returns a single tensor shaped like its input.
    inputs = torch.randn(2, 3, layer.gate.hidden_dim)
    result = layer(inputs)
    self.assertIsInstance(result, torch.Tensor)
    self.assertEqual(result.shape, inputs.shape)

    # Running forward must have recorded at least one expert selection.
    self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
def test_qwen35_moe_adapter_bias_updates(self):
    """Run forward, then the update callback, and check the Qwen 3.5 state."""
    model, layer = _build_qwen35_moe_model()
    config = _cfg()
    aux_free_plugin = AuxFreeMoEPlugin()
    aux_free_plugin.post_model_build(config, model)

    layer(torch.randn(2, 4, layer.gate.hidden_dim))

    # No update has run yet, so the bias is still all zeros.
    zero_bias = torch.zeros_like(layer._afb_bias)
    self.assertTrue(torch.allclose(layer._afb_bias, zero_bias))

    _run_callback(aux_free_plugin, config)

    # The callback drains the counters and moves the EMA away from zero.
    self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
    ema = layer._afb_ema
    self.assertFalse(torch.allclose(ema, torch.zeros_like(ema)))
def test_qwen35_moe_adapter_model_type_matching(self):
    """``matches`` accepts qwen3_5_moe and qwen3_5_moe_text, not qwen3_moe."""
    from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter

    def stub_model(model_type):
        # Only ``config.model_type`` is populated on the stand-in model.
        return SimpleNamespace(config=SimpleNamespace(model_type=model_type))

    adapter = Qwen35MoeAdapter()
    self.assertTrue(adapter.matches(stub_model("qwen3_5_moe")))
    self.assertTrue(adapter.matches(stub_model("qwen3_5_moe_text")))
    self.assertFalse(adapter.matches(stub_model("qwen3_moe")))
def test_ep_group_resolution_deferred_until_dist_ready(self): def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
dist.destroy_process_group() dist.destroy_process_group()
@@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
tmp_init = tempfile.NamedTemporaryFile(delete=False) tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close() tmp_init.close()
init_method = f"file://{tmp_init.name}" init_method = f"file://{tmp_init.name}"
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0) dist.init_process_group(
backend="gloo", init_method=init_method, world_size=1, rank=0
)
try: try:
hidden = torch.randn(2, 3, block.config.hidden_size) hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden) block(hidden)
@@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_telemetry_logging(self): def test_telemetry_logging(self):
model, layer = _build_mixtral_model() model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg() cfg = _cfg()
plugin = AuxFreeMoEPlugin() plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model) plugin.post_model_build(cfg, model)
@@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase):
self.assertIn("moe_afb/l0_load_max", telemetry) self.assertIn("moe_afb/l0_load_max", telemetry)
self.assertIn("moe_afb/l0_bias_abs_max", telemetry) self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
def test_get_num_experts_v5_attribute_paths(self):
    """``get_num_experts`` finds the count on sub-modules or on the block."""
    from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

    adapter = MixtralAdapter()

    # v5-style layout: only the gate/experts sub-objects carry num_experts.
    nested_block = SimpleNamespace(
        gate=SimpleNamespace(num_experts=8),
        experts=SimpleNamespace(num_experts=8),
    )
    self.assertEqual(adapter.get_num_experts(nested_block), 8)

    # Flat layout: num_experts lives directly on the block.
    flat_block = SimpleNamespace(num_experts=4)
    self.assertEqual(adapter.get_num_experts(flat_block), 4)
class TestAuxFreeKernelComposition(unittest.TestCase):
    """Tests that aux-free bias composes correctly with kernel routing.

    The routing tests drive the kernel routing functions directly with a
    stand-in MoE block; the adapter tests use a block class that carries the
    ``_original_forward`` marker to simulate a kernel-patched block.
    """

    def test_sonicmoe_softmax_routing_with_afb_bias(self):
        """A dominant bias makes SonicMoE softmax routing pick that expert
        once for every token, and the selections are tallied."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts = 4
        top_k = 2
        hidden_dim = 16
        T = 6

        # Build a mock MoE block with gate attributes
        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True
        moe_block = SimpleNamespace(gate=gate)

        hidden = torch.randn(T, hidden_dim)

        # Baseline: no bias registered, one score per (token, slot) pair
        scores_base, _, _, _ = softmax_topk_routing(hidden, moe_block)
        self.assertEqual(scores_base.shape[0], T * top_k)

        # Now register aux-free buffers and set heavy bias on expert 0
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)

        _, _, exp_biased, _ = softmax_topk_routing(hidden, moe_block)

        # Expert 0 must be selected for every token. A token cannot pick the
        # same expert twice, so that means exactly T occurrences overall.
        self.assertEqual(
            int((exp_biased == 0).sum().item()),
            T,
            "Expert 0 should be selected exactly once per token when heavily biased",
        )
        # The tally for expert 0 must agree with the selections
        self.assertEqual(int(moe_block._afb_counts[0].item()), T)
        # Total counts should equal T * top_k
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_sonicmoe_routing_without_bias_unchanged(self):
        """An all-zero ``_afb_bias`` must route identically to having no
        aux-free attributes on the block at all."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts = 4
        top_k = 2
        hidden_dim = 16

        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True
        moe_block = SimpleNamespace(gate=gate)

        hidden = torch.randn(4, hidden_dim)

        # Without _afb_bias attribute
        scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)

        # With _afb_bias = zeros (should be equivalent)
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_counts = torch.zeros(num_experts)
        scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)

        torch.testing.assert_close(scores1, scores2)
        torch.testing.assert_close(exp1, exp2)

    @unittest.skipUnless(
        importlib_util.find_spec("triton") is not None,
        "triton not installed (required by scattermoe)",
    )
    def test_scattermoe_softmax_routing_with_afb_bias(self):
        """A dominant bias makes ScatterMoE softmax routing pick that expert
        once for every token, and the selections are tallied."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )

        num_experts = 4
        top_k = 2
        hidden_dim = 16
        T = 6

        gate_weight = torch.randn(num_experts, hidden_dim)
        base_gate = SimpleNamespace(
            top_k=top_k,
            num_experts=num_experts,
            norm_topk_prob=True,
            weight=gate_weight,
        )
        moe_block = SimpleNamespace()

        hidden = torch.randn(T, hidden_dim)

        # Baseline without bias: only checks that routing runs when the block
        # carries no aux-free attributes.
        _softmax_topk_route(moe_block, base_gate, hidden, gate_weight, None)

        # With heavy bias on expert 0
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)

        _, e_biased, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )

        # Expert 0 must appear once in every token's selection
        self.assertEqual(
            int((e_biased == 0).sum().item()),
            T,
            "Expert 0 should be selected exactly once per token when heavily biased",
        )
        # Counts accumulated and consistent with the selections
        self.assertEqual(int(moe_block._afb_counts[0].item()), T)
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_kernel_routing_skips_router_patch(self):
        """A block whose class carries the kernel marker is reported as using
        kernel routing, and its gate carries no aux-free patch marker."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        adapter = MixtralAdapter()

        # Create a mock layer whose class has _original_forward (SonicMoE marker)
        class PatchedBlock(nn.Module):
            _original_forward = True  # SonicMoE marker

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)  # placeholder

        layer = PatchedBlock()
        self.assertTrue(adapter.uses_kernel_routing(layer))

        # NOTE: nothing in this test attempts to patch the layer, so this only
        # confirms that the detection call leaves the gate unmarked. The
        # "prepare() must not patch the gate" path is exercised by
        # test_adapter_buffers_registered_even_with_kernel.
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))

    def test_adapter_buffers_registered_even_with_kernel(self):
        """With kernel routing active, ``prepare`` still registers the aux-free
        buffers on the block but leaves the gate unpatched."""
        from axolotl.integrations.aux_free_router.adapters import (
            LayerHandle,
            MixtralAdapter,
        )
        from axolotl.integrations.aux_free_router.core import (
            AuxFreeConfig,
            AuxFreeShim,
            AuxFreeState,
        )

        class PatchedBlock(nn.Module):
            _original_forward = True

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)

        layer = PatchedBlock()
        adapter = MixtralAdapter()

        cfg = AuxFreeConfig()
        state = AuxFreeState(
            num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
        )
        shim = AuxFreeShim(state=state)
        handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)

        adapter.prepare(layer, handle, shim)

        # Buffers should be registered for kernel routing to use
        self.assertTrue(hasattr(layer, "_afb_bias"))
        self.assertTrue(hasattr(layer, "_afb_counts"))
        self.assertTrue(hasattr(layer, "_afb_ema"))
        # But gate should NOT be patched
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()