Compare commits

...

14 Commits

Author SHA1 Message Date
Wing Lian
6636e5de7e address PR code review 2026-03-22 17:23:12 +00:00
Wing Lian
0a566d7a15 chore: lint 2026-03-22 12:02:44 -04:00
Wing Lian
5acb1b0ade update for transformers v5 for experts parameters and compose with moe kernels 2026-03-22 11:52:34 -04:00
lhl
4009a2ba5f reordered our tests to mirror llm_compressor for prepare_plugins/validate order 2026-03-22 11:16:03 -04:00
lhl
66b2ab8414 move configs from global config to plugin specific args 2026-03-22 11:16:03 -04:00
lhl
676d5e855d improve: align aux-free telemetry with Trainer logging 2026-03-22 11:16:03 -04:00
lhl
966a4555db fix: make aux-free mixtral adapter GPU-safe 2026-03-22 11:16:03 -04:00
lhl
ad0c825bcb sample packing and telemetry docs 2026-03-22 11:16:03 -04:00
lhl
46d677876e tests: add ring and llama4 aux-free smokes
- add hf tiny model coverage for Ring 2.0 and Llama 4 adapters

- broaden bailing adapter detection for ring configs
2026-03-22 11:16:03 -04:00
lhl
6eac9ac372 aux_free_router: emit telemetry metrics
- log per-layer load stats, bias magnitude, and oscillation rate

- honor configurable telemetry interval with Trainer logging integration
2026-03-22 11:16:02 -04:00
lhl
949cdf01eb tests: extend aux-free coverage
- add warmup, EP sync, and mixtral parity unit checks
2026-03-22 11:16:02 -04:00
lhl
a0019021dd aux_free_router: sync shim state
- drive warmup-aware bias updates and register live buffers
2026-03-22 11:16:02 -04:00
lhl
2af7475fdf Add ring/llama4 aux-free adapters and EP sync support 2026-03-22 11:16:02 -04:00
lhl
3e4688289c feat(moe-aux-loss-free): aux-free MoE plugin (Mixtral/Qwen3), EMA bias updates, config keys; E2E smoke + parity tests 2026-03-22 11:16:02 -04:00
18 changed files with 2129 additions and 8 deletions

View File

@@ -0,0 +1,50 @@
# Aux-Loss-Free MoE Router Plugin
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
Summary
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
Enable
- Add the plugin to your YAML, then set the aux-free toggle:
plugins:
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
moe_balance_type: noaux_tc
moe_update_rate: 0.01 # default if unset
moe_update_momentum: 0.9 # default if unset
moe_bias_cap: 2.0 # default if unset
moe_afb_warmup_steps: 100 # optional
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
Config keys
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
- moe_update_rate: bias update rate (gamma). Default: 0.01.
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
Compatibility
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
Notes
- If you also enable Liger's aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction using the Trainer's `logging_steps` cadence.
- Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.
Telemetry metrics
- `moe_afb/l{idx}_load_min|mean|max`: token frequency per expert after reduction (0–1 range, sums to 1).
- `moe_afb/l{idx}_bias_abs_max`: absolute maximum of the learned bias for the layer.
- `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (simple oscillation indicator).
Usage tips
- Increase `logging_steps` if router telemetry becomes noisy for large jobs—the plugin follows the Trainer's logging cadence.
- Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens min/max spread without the auxiliary loss term.

View File

@@ -0,0 +1,9 @@
"""Aux-loss-free (AFB) MoE router integration package."""

from .args import AuxFreeRouterArgs
from .plugin import AuxFreeMoEPlugin

# Explicit public API surface of the integration package.
__all__ = [
    "AuxFreeMoEPlugin",
    "AuxFreeRouterArgs",
]

View File

@@ -0,0 +1,393 @@
"""Architecture-specific adapters for aux-loss-free MoE routing.
Each adapter discovers MoE layers for a model family and patches only the
router/gate to inject per-expert bias into expert selection while keeping
mixture weights from unbiased logits. Expert dispatch is left untouched so
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
import torch.nn.functional as F
from torch import nn
from axolotl.utils.logging import get_logger
from .core import AuxFreeShim
LOG = get_logger(__name__)
@dataclass
class LayerHandle:
    """Bookkeeping record for one discovered MoE layer."""

    layer: nn.Module  # the MoE block module that was prepared/patched
    layer_idx: int  # dense 0-based index assigned during discovery
    num_experts: int  # routed expert count resolved from the layer
    top_k: int  # experts selected per token, resolved from the layer
class BaseMoEAdapter:
    """Base adapter that discovers MoE layers and patches their routing.

    Concrete adapters implement discovery, attribute extraction, and
    architecture-specific router patching.
    """

    family: str = "generic"

    def matches(self, model: nn.Module) -> bool:  # pragma: no cover - thin shim
        """Whether this adapter handles ``model``; the base never matches."""
        return False

    def find_moe_layers(
        self, model: nn.Module
    ) -> Iterable[nn.Module]:  # pragma: no cover
        """Yield the MoE blocks of ``model``; the base yields nothing."""
        return []

    @staticmethod
    def _first_int_attr(root: object, attr_paths) -> Optional[int]:
        """Walk each attribute path on ``root``; return the first int found."""
        for attr_path in attr_paths:
            node: object = root
            for attr in attr_path:
                node = getattr(node, attr, None)
                if node is None:
                    break
            if isinstance(node, int):
                return node
        return None

    def get_top_k(self, moe_layer: nn.Module) -> int:
        """Resolve top_k from the MoE layer, checking common attribute paths."""
        found = self._first_int_attr(
            moe_layer,
            (
                ("top_k",),
                ("num_experts_per_tok",),
                ("gate", "top_k"),
                ("router", "top_k"),
            ),
        )
        # Fall back to the common default of 2 experts per token.
        return 2 if found is None else found

    def get_num_experts(self, moe_layer: nn.Module) -> int:
        """Resolve num_experts from the MoE layer, checking common attribute paths."""
        found = self._first_int_attr(
            moe_layer,
            (
                ("num_experts",),
                ("num_local_experts",),
                ("gate", "num_experts"),
                ("router", "num_experts"),
                ("experts", "num_experts"),
            ),
        )
        if found is None:
            raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
        return found

    def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
        """Best-effort: zero the router aux-loss coefficient when present."""
        if not hasattr(model_or_layer, "router_aux_loss_coef"):
            return
        try:
            model_or_layer.router_aux_loss_coef = 0.0
        except Exception:  # pragma: no cover - non-critical
            LOG.debug(
                "disable_aux_loss: failed to set router_aux_loss_coef on %s",
                type(model_or_layer).__name__,
                exc_info=True,
            )

    def _register_aux_buffers(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        """Attach bias/counts/EMA buffers to the layer and bind them to the shim."""
        # Infer the target device from any parameter, then any buffer, else CPU.
        anchor = next(moe_layer.parameters(), None)
        if anchor is None:
            anchor = next(moe_layer.buffers(), None)
        device = anchor.device if anchor is not None else torch.device("cpu")
        for buffer_name in ("_afb_bias", "_afb_counts", "_afb_ema"):
            if not hasattr(moe_layer, buffer_name):
                moe_layer.register_buffer(
                    buffer_name, torch.zeros(handle.num_experts, device=device)
                )
        moe_layer._afb_layer_idx = handle.layer_idx  # type: ignore[attr-defined]
        moe_layer._afb_top_k = handle.top_k  # type: ignore[attr-defined]
        shim.register_layer_buffers(handle.layer_idx, moe_layer)

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        """Attach per-layer buffers. Subclasses override to also patch routing."""
        self._register_aux_buffers(moe_layer, handle, shim)

    def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
        """Return True when a kernel backend (SonicMoE / ScatterMoE) has
        already replaced the block forward, meaning the routing is handled
        inside the kernel forward and we should NOT patch the router."""
        # SonicMoE leaves `_original_forward` on the class when it patches it;
        # ScatterMoE (via the kernels library) leaves `_kernel_forward`.
        layer_cls = type(moe_layer)
        return any(
            hasattr(layer_cls, marker)
            for marker in ("_original_forward", "_kernel_forward")
        )
class MixtralAdapter(BaseMoEAdapter):
    """Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
    routing so that biased logits drive expert *selection* while unbiased
    softmax scores drive mixture *weights*.

    Works with transformers v5 where experts are fused 3D tensors and
    the router is ``MixtralTopKRouter`` (returns a 3-tuple).
    """

    family = "mixtral"

    def matches(self, model: nn.Module) -> bool:
        # Match on the HF config's model_type; a missing config never matches.
        return (
            getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
        )

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        # Discover by class-name suffix so Mixtral-style forks are caught too.
        for m in model.modules():
            if m.__class__.__name__.endswith("SparseMoeBlock"):
                yield m

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        """Attach buffers, then patch the router unless a kernel backend owns routing."""
        self._register_aux_buffers(moe_layer, handle, shim)
        if not self.uses_kernel_routing(moe_layer):
            self._patch_router(moe_layer)
        else:
            LOG.info(
                "AuxFreeMoE: kernel backend detected on %s; "
                "skipping router patch (kernel routing handles bias)",
                type(moe_layer).__name__,
            )

    def _patch_router(self, moe_layer: nn.Module) -> None:
        """Patch the TopKRouter to inject aux-free bias into expert selection."""
        gate = getattr(moe_layer, "gate", None)
        if gate is None:
            LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
            return
        if getattr(gate, "_afb_patched", False):
            # Idempotent: never stack a second wrapper on an already-patched gate.
            return
        # Capture reference to the MoE block for bias / counts access.
        block_ref = moe_layer

        def afb_router_forward(self, hidden_states: torch.Tensor):
            # NOTE: `self` here is the gate module this function is bound to.
            hidden_states = hidden_states.reshape(-1, self.hidden_dim)
            router_logits = F.linear(hidden_states, self.weight)
            router_probs = F.softmax(router_logits.float(), dim=-1)
            # Biased selection, unbiased weights
            bias = block_ref._afb_bias
            biased = router_probs + bias
            _, router_indices = torch.topk(biased, self.top_k, dim=-1)
            # Mixture weights come from the *unbiased* probabilities.
            router_scores = torch.gather(router_probs, 1, router_indices)
            # Renormalize (Mixtral always normalizes; Qwen checks config)
            if getattr(self, "norm_topk_prob", True):
                router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
            # Accumulate counts for the bias-update callback
            flat_idx = router_indices.reshape(-1)
            counts = torch.bincount(flat_idx, minlength=self.num_experts)
            block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
            return router_probs, router_scores, router_indices

        gate.forward = afb_router_forward.__get__(gate, gate.__class__)  # type: ignore[attr-defined]
        gate._afb_patched = True
        moe_layer._afb_patched = True
class Qwen3Adapter(MixtralAdapter):
    """Mixtral-style adapter covering the Qwen2 and Qwen3 MoE families."""

    family = "qwen3_moe"

    def matches(self, model: nn.Module) -> bool:
        """Match either generation of the Qwen MoE model type."""
        cfg = getattr(model, "config", object())
        model_type = getattr(cfg, "model_type", "")
        return model_type in ("qwen3_moe", "qwen2_moe")
class Qwen35MoeAdapter(MixtralAdapter):
    """Adapter for Qwen 3.5 MoE models.

    Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
    is handled by the block forward (untouched by router-level patching).
    """

    family = "qwen3_5_moe"

    def matches(self, model: nn.Module) -> bool:
        """Match the Qwen 3.5 MoE model type (including the text-only variant)."""
        cfg = getattr(model, "config", object())
        return getattr(cfg, "model_type", "") in ("qwen3_5_moe", "qwen3_5_moe_text")
class BailingAdapter(BaseMoEAdapter):
    """Adapter for Bailing / Ring MoE models (sigmoid scores + grouped top-k gate)."""

    family = "bailing_moe"

    def matches(self, model: nn.Module) -> bool:
        cfg = getattr(model, "config", None)
        if cfg is None:
            return False
        model_type = getattr(cfg, "model_type", "") or ""
        if model_type in ("bailing_moe", "bailing_moe_v2", "ring_moe", "ring"):
            return True
        # Fallback for configs with unset/custom model_type: match on the
        # config class name (covers BailingMoeV2* and Ring* variants).
        cfg_name = cfg.__class__.__name__.lower()
        return "bailingmoev2" in cfg_name or "ring" in cfg_name

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        for m in model.modules():
            if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
                yield m

    def get_num_experts(self, moe_layer: nn.Module) -> int:
        # Prefer the layer attribute; otherwise fall back to the layer's config.
        if hasattr(moe_layer, "num_experts"):
            return int(moe_layer.num_experts)
        cfg = getattr(moe_layer, "config", None)
        if cfg is None:
            raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
        return int(cfg.num_experts)

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        """Attach buffers and patch the Bailing gate."""
        self._register_aux_buffers(moe_layer, handle, shim)
        self._patch_bailing_gate(moe_layer)

    def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
        """Patch the gate: biased grouped top-k selection, unbiased mixture weights."""
        gate = getattr(moe_layer, "gate", None)
        if gate is None:
            LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
            return
        if getattr(gate, "_afb_patched", False):
            # Idempotent: don't re-wrap an already patched gate.
            return

        def afb_gate_forward(self, hidden_states: torch.Tensor):
            # `self` is the bound gate module.
            flat = hidden_states.view(-1, hidden_states.shape[-1])
            logits = F.linear(flat.float(), self.weight.float())
            scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
            bias = moe_layer._afb_bias
            # Bias only influences which experts are chosen...
            biased_scores = scores_unbiased + bias
            # group_limited_topk is the gate's own grouped selection routine
            # (assumed present on Bailing gates — TODO confirm across variants).
            _, topk_idx = self.group_limited_topk(biased_scores)
            # ...while mixture weights are gathered from the unbiased scores.
            weights = torch.gather(scores_unbiased, 1, topk_idx)
            if self.top_k > 1:
                # clamp_min_ guards against division by ~0 when scores vanish.
                denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
                weights = weights / denom
            weights = weights * self.routed_scaling_factor
            # Accumulate per-expert token counts for the post-step bias update.
            flat_topk = topk_idx.reshape(-1)
            counts = torch.bincount(flat_topk, minlength=bias.numel())
            moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
            return topk_idx, weights.to(hidden_states.dtype), logits

        gate.forward = afb_gate_forward.__get__(gate, gate.__class__)  # type: ignore[attr-defined]
        gate._afb_patched = True
class Llama4Adapter(BaseMoEAdapter):
    """Adapter for Llama 4 text MoE layers (sigmoid router returning a 2-tuple)."""

    family = "llama4"

    def matches(self, model: nn.Module) -> bool:
        return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        for m in model.modules():
            if m.__class__.__name__ == "Llama4TextMoe":
                yield m

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        """Attach buffers and patch the Llama4 router."""
        self._register_aux_buffers(moe_layer, handle, shim)
        self._patch_llama4_router(moe_layer)

    def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
        """Patch the router: biased top-k selection, scores from unbiased logits."""
        router = getattr(moe_layer, "router", None)
        if router is None:
            LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
            return
        if getattr(router, "_afb_patched", False):
            # Idempotent: don't re-wrap an already patched router.
            return

        def afb_router_forward(self, hidden_states: torch.Tensor):
            # `self` is the bound router module.
            flat = (
                hidden_states
                if hidden_states.dim() == 2
                else hidden_states.view(-1, hidden_states.shape[-1])
            )
            router_logits = F.linear(flat, self.weight, self.bias)
            bias = moe_layer._afb_bias
            # Bias shifts only the logits used for selection.
            biased_logits = router_logits + bias
            _, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
            # Scores come from the *unbiased* logits of the chosen experts;
            # non-selected experts get -inf, which sigmoid maps to 0.
            unbiased_top = torch.gather(router_logits, 1, router_indices)
            router_scores = torch.full_like(router_logits, float("-inf"))
            router_scores.scatter_(1, router_indices, unbiased_top)
            router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
            # Accumulate per-expert token counts for the post-step bias update.
            counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
            moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
            return router_scores, router_logits

        router.forward = afb_router_forward.__get__(router, router.__class__)  # type: ignore[attr-defined]
        router._afb_patched = True
def discover_and_prepare_layers(
    model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
) -> list[LayerHandle]:
    """Discover MoE layers using the first matching adapter and attach per-layer buffers.

    Returns a list of layer handles for later routing patching and updates.
    """
    handles: list[LayerHandle] = []
    # First adapter whose matches() accepts this model wins.
    adapter = next((a for a in adapters if a.matches(model)), None)
    if adapter is None:
        LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
        return handles
    # Disable the model-level aux loss up front when the config exposes it.
    adapter.disable_aux_loss(getattr(model, "config", model))
    for layer in adapter.find_moe_layers(model):
        try:
            top_k = adapter.get_top_k(layer)
            num_experts = adapter.get_num_experts(layer)
        except (AttributeError, TypeError, ValueError):
            # Layers we cannot introspect are skipped without consuming an index.
            continue
        handle = LayerHandle(
            layer=layer,
            layer_idx=len(handles),
            num_experts=num_experts,
            top_k=top_k,
        )
        adapter.prepare(layer, handle, shim)
        handles.append(handle)
    LOG.info(
        "AuxFreeMoE: prepared %d %s layers for aux-free routing",
        len(handles),
        adapter.family,
    )
    return handles

View File

@@ -0,0 +1,71 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plugin args for the Aux-Loss-Free MoE router integration.
"""
from typing import Literal
from pydantic import BaseModel, Field
class AuxFreeRouterArgs(BaseModel):
    """
    Input args for Aux-Loss-Free MoE routing.
    """

    # NOTE: every field defaults to None so the plugin can tell "unset"
    # (use its built-in default) apart from an explicit user choice.
    moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
            "'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
            "Defaults to model's native behavior when unset."
        },
    )
    moe_update_rate: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
            "If unset, plugin default is 0.01."
        },
    )
    moe_update_momentum: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "EMA momentum for expert load smoothing (0-1). "
            "If unset, plugin default is 0.9."
        },
    )
    moe_bias_cap: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Absolute clamp for expert bias magnitude. "
            "If unset, plugin default is 2.0."
        },
    )
    moe_afb_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of initial steps to delay aux-free bias updates, "
            "allowing routing to stabilize. If unset, plugin default is 0."
        },
    )
    moe_bias_sync_group: Literal["world", "ep"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Reduction group for expert load counts: 'world' (DP) or "
            "'ep' (expert-parallel group if available). Defaults to 'world' when unset."
        },
    )

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@dataclass
class AuxFreeConfig:
    """Hyperparameters for aux-loss-free bias updates."""

    rate: float = 0.01  # per-step bias update rate (gamma)
    momentum: float = 0.9  # EMA momentum for smoothing expert loads
    bias_cap: float = 2.0  # absolute clamp applied to bias after each update
    warmup_steps: int = 0  # optimizer steps to wait before applying updates
    sync_group: str = "world"  # or "ep"
class AuxFreeState:
    """Holds per-layer bias and EMA load buffers."""

    def __init__(
        self,
        num_layers: int,
        num_experts: int,
        device: torch.device,
        cfg: AuxFreeConfig,
    ):
        """Allocate zeroed bias and EMA-load vectors for every layer."""

        def _fresh() -> torch.Tensor:
            return torch.zeros(num_experts, device=device)

        self.bias = [_fresh() for _ in range(num_layers)]
        self.ema_load = [_fresh() for _ in range(num_layers)]
        self.cfg = cfg
        # Count of optimizer steps observed so far.
        self.steps = 0
class AuxFreeShim:
    """Model-agnostic shim for aux-loss-free expert selection and bias updates."""

    def __init__(
        self,
        state: AuxFreeState,
        ep_group: Optional[dist.ProcessGroup] = None,
        ep_size: Optional[int] = None,
    ):
        """
        Args:
            state: shared per-layer bias/EMA state.
            ep_group: pre-built expert-parallel process group, if any.
            ep_size: configured expert-parallel width, used to lazily build
                the EP group once torch.distributed is initialized.
        """
        self.state = state
        self.ep_group = ep_group
        self._ep_size = ep_size
        # When 'ep' sync is requested but no group exists yet, defer group
        # creation until the first all-reduce after distributed comes up.
        self._ep_group_pending = (
            self.state.cfg.sync_group == "ep" and self.ep_group is None
        )
        self._layer_modules: dict[int, torch.nn.Module] = {}
        self._prev_bias_sign: dict[int, torch.Tensor] = {}

    @torch.no_grad()
    def select_experts(
        self, layer_idx: int, logits: torch.Tensor, top_k: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (topk_indices, weights) using biased selection and unbiased weights."""
        # Prefer the live module buffer (kept current by register_layer_buffers);
        # fall back to the state copy for unregistered layers.
        module = self._layer_modules.get(layer_idx)
        if module is not None and hasattr(module, "_afb_bias"):
            b = module._afb_bias
        else:
            b = self.state.bias[layer_idx]
        biased = logits + b  # bias is a buffer
        _topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
        # Mixture weights come from the *unbiased* logits of the chosen experts.
        chosen_logits = torch.gather(logits, -1, topk_idx)
        weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
        return topk_idx, weights

    def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
        """Bind model buffers so shim updates stay in sync with patched layers."""
        self._layer_modules[layer_idx] = module
        bias = module._afb_bias
        ema = module._afb_ema
        # Keep state views pointing to the same tensors to avoid drift.
        if layer_idx < len(self.state.bias):
            self.state.bias[layer_idx] = bias
        if layer_idx < len(self.state.ema_load):
            self.state.ema_load[layer_idx] = ema

    def begin_step(self) -> None:
        """Call once per optimizer step before per-layer updates."""
        self.state.steps += 1

    def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
        """Sign vector recorded after the layer's previous bias update, if any."""
        return self._prev_bias_sign.get(layer_idx)

    @torch.no_grad()
    def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
        """Sum expert counts across the configured group (world or EP), in place."""
        self._maybe_init_ep_group()
        if not dist.is_available() or not dist.is_initialized():
            # Single-process run: local counts are already global.
            return counts
        group = self.ep_group if self.ep_group is not None else dist.group.WORLD
        dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
        return counts

    @torch.no_grad()
    def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
        """Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
        cfg = self.state.cfg
        if self.state.steps <= cfg.warmup_steps:
            # Still in warmup: leave bias untouched.
            return
        nE = step_counts.numel()
        if tokens_seen <= 0:
            # No routed tokens this step; nothing to learn from.
            return
        # Use the live module buffers when registered; fall back to state copies.
        module = self._layer_modules.get(layer_idx)
        if module is not None and hasattr(module, "_afb_ema"):
            ema = module._afb_ema
            bias = module._afb_bias
        else:
            ema = self.state.ema_load[layer_idx]
            bias = self.state.bias[layer_idx]
        counts = step_counts.to(ema.device)
        freq = counts.float() / float(tokens_seen)
        ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
        target = 1.0 / float(nE)
        # Under-loaded experts (ema < target) get a positive delta → more selection.
        delta = cfg.rate * (target - ema)
        # optional mean-centering to keep sum(bias) ~ 0
        delta = delta - delta.mean()
        bias.add_(delta)
        if cfg.bias_cap is not None and cfg.bias_cap > 0:
            bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
        # Record the sign pattern for oscillation telemetry.
        self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())

    def _maybe_init_ep_group(self) -> None:
        """Lazily build the EP reduction group once torch.distributed is up."""
        if not self._ep_group_pending:
            return
        if not dist.is_available() or not dist.is_initialized():
            # Still too early; keep the pending flag and retry on next call.
            return
        ep_size = self._ep_size
        if not ep_size or ep_size <= 1:
            LOG.warning(
                "AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
            )
            self.ep_group = dist.group.WORLD
            self._ep_group_pending = False
            return
        world = dist.get_world_size()
        if world % ep_size != 0:
            LOG.warning(
                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
                ep_size,
                world,
            )
            self.ep_group = dist.group.WORLD
            self._ep_group_pending = False
            return
        if ep_size == world:
            self.ep_group = dist.group.WORLD
        else:
            # NOTE(review): dist.new_group is a collective call; this branch
            # creates only this rank's subgroup, which assumes every rank
            # reaches here in the same order. Confirm consistency with the
            # plugin-side _resolve_ep_group, which creates all subgroups on
            # all ranks.
            rank = dist.get_rank()
            group_start = (rank // ep_size) * ep_size
            ranks = tuple(range(group_start, group_start + ep_size))
            self.ep_group = dist.new_group(ranks)
        LOG.info(
            "AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
            ep_size,
            dist.get_world_size(),
        )
        self._ep_group_pending = False

View File

@@ -0,0 +1,267 @@
"""Aux-loss-free MoE Router Plugin for Axolotl.
This plugin wires an aux-free gating option into compatible MoE models using
unbiased logits for mixture weights and per-expert biases for top-k selection.
"""
from __future__ import annotations
from typing import Any, Optional
import torch
import torch.distributed as dist
from transformers.trainer_callback import TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .adapters import (
BailingAdapter,
BaseMoEAdapter,
Llama4Adapter,
MixtralAdapter,
Qwen3Adapter,
Qwen35MoeAdapter,
discover_and_prepare_layers,
)
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
LOG = get_logger(__name__)
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
    """Post-step callback to update aux-free biases from accumulated expert counts."""

    def __init__(
        self,
        shim: AuxFreeShim,
        layer_modules: list[torch.nn.Module],
        trainer: Any,
    ):
        """
        Args:
            shim: shared shim holding config, EMA state, and reduction logic.
            layer_modules: prepared MoE blocks carrying the ``_afb_*`` buffers.
            trainer: Trainer instance used to emit telemetry via ``trainer.log``.
        """
        self.shim = shim
        self.layer_modules = layer_modules
        self.trainer = trainer
        # Telemetry-local sign history (kept separate from the shim's copy).
        self._prev_bias_sign: dict[int, torch.Tensor] = {}
        self._telemetry_buffer: dict[int, dict[str, float]] = {}

    def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
        """Reduce counts, update biases, collect telemetry, then reset counters."""
        # Iterate prepared MoE layers and apply the bias update rule.
        self.shim.begin_step()
        for layer in self.layer_modules:
            if not hasattr(layer, "_afb_counts") or not hasattr(
                layer, "_afb_layer_idx"
            ):
                continue
            counts = layer._afb_counts
            if counts is None:
                continue
            # Sum this step's counts across the configured process group.
            counts = self.shim.all_reduce_counts(counts)
            layer_idx = getattr(layer, "_afb_layer_idx", None)
            if layer_idx is None:
                # Un-indexed layer: discard this step's counts and move on.
                counts.zero_()
                continue
            bias = layer._afb_bias
            counts_for_update = counts.to(bias.device)
            tokens_seen = int(counts_for_update.sum().item())
            # local layer-state EMA and bias update
            self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
            # Telemetry sees the bias *after* this step's update.
            self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
            # reset step counts
            counts.zero_()
        if self._should_log(args, state) and self._telemetry_buffer:
            # Flatten buffered per-layer metrics into one flat log payload.
            logs: dict[str, float] = {}
            for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
                prefix = f"moe_afb/l{layer_idx}_"
                for key, value in metrics.items():
                    logs[f"{prefix}{key}"] = value
            if logs and hasattr(self.trainer, "log"):
                self.trainer.log(logs)
            self._telemetry_buffer.clear()
        return control

    def _collect_telemetry(
        self,
        layer_idx: int,
        counts: torch.Tensor,
        tokens_seen: int,
        bias: torch.Tensor,
    ) -> None:
        """Stash per-layer load/bias metrics for the next logging flush."""
        if tokens_seen <= 0:
            return
        # Normalize counts to per-expert token frequency (sums to 1).
        freq = counts.float() / float(tokens_seen)
        load_min = freq.min().item()
        load_mean = freq.mean().item()
        load_max = freq.max().item()
        bias_abs_max = bias.abs().max().item()
        prev_sign = self._prev_bias_sign.get(layer_idx)
        current_sign = torch.sign(bias.detach())
        if prev_sign is None or prev_sign.shape != current_sign.shape:
            # First observation (or shape change): no oscillation to report.
            oscillation = 0.0
        else:
            # A "flip" is a sign change where at least one side is nonzero.
            changed = (current_sign != prev_sign) & (
                (current_sign != 0) | (prev_sign != 0)
            )
            if changed.numel() == 0:
                oscillation = 0.0
            else:
                oscillation = changed.float().mean().item()
        self._prev_bias_sign[layer_idx] = current_sign.clone()
        self._telemetry_buffer[layer_idx] = {
            "load_min": load_min,
            "load_mean": load_mean,
            "load_max": load_max,
            "bias_abs_max": bias_abs_max,
            "bias_sign_flip_frac": oscillation,
        }

    def _should_log(self, args, state) -> bool:
        """Mirror the Trainer's logging cadence via ``args.logging_steps``."""
        interval = getattr(args, "logging_steps", 0)
        if not interval:
            return False
        try:
            interval = max(1, int(interval))
        except (TypeError, ValueError):
            return False
        # interval >= 1 at this point, so this reduces to the modulo check.
        return interval > 0 and state.global_step % interval == 0
class AuxFreeMoEPlugin(BasePlugin):
    """Plugin that enables aux-loss-free routing when configured."""

    def __init__(self):
        super().__init__()
        self._handles: list = []  # LayerHandle records from discovery
        self._shim: Optional[AuxFreeShim] = None  # built in post_model_build
        # Cache of EP subgroups keyed by rank tuple, so each subgroup is
        # created exactly once per process (dist.new_group is collective).
        self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}

    def get_input_args(self):
        """Dotted path to the pydantic args model merged into the axolotl config."""
        return "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"

    def post_model_build(self, cfg, model):
        """Build shim state and patch MoE layers when ``noaux_tc`` is requested."""
        # Enable only when explicitly requested
        if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
            return
        # Be conservative — skip known native aux-free families
        native_auxfree = getattr(
            getattr(model, "config", object()), "model_type", ""
        ) in (
            "deepseek_v3",
            "glm4_moe",
        )
        if native_auxfree:
            LOG.info(
                "AuxFreeMoE: model reports native aux-free routing; skipping patching"
            )
            return
        # Build aux-free state and shim
        rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
        momentum = (
            cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
        )
        bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
        warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
        sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
        af_cfg = AuxFreeConfig(
            rate=rate,
            momentum=momentum,
            bias_cap=bias_cap,
            warmup_steps=warmup,
            sync_group=sync_group,
        )
        # Discover layers to count the number and experts for state sizing
        adapters: list[BaseMoEAdapter] = [
            MixtralAdapter(),
            Qwen3Adapter(),
            Qwen35MoeAdapter(),
            BailingAdapter(),
            Llama4Adapter(),
        ]
        # For initial state sizing, we conservatively assume the first discovered layer defines nE
        n_layers = 0
        n_experts = None
        for _m in model.modules():
            n_layers += 1  # upper bound — we will re-use bias slots sparsely
        # Fallback tensor provides a (CPU) device for parameter-less models.
        device = next(model.parameters(), torch.tensor(0)).device
        if n_layers <= 0:
            n_layers = 1
        if n_experts is None:
            # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
            n_experts = 2
        state = AuxFreeState(
            num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
        )
        ep_size = getattr(cfg, "expert_parallel_size", None)
        ep_group = None
        if sync_group == "ep":
            if dist.is_available() and dist.is_initialized():
                ep_group = self._resolve_ep_group(cfg)
            else:
                # The shim will lazily build the EP group on first all-reduce.
                LOG.info(
                    "AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
                )
        self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
        # Discover and prepare layers (attach per-layer buffers)
        self._handles = discover_and_prepare_layers(model, adapters, self._shim)
        LOG.info(
            f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
        )

    def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
        """Resolve (creating if needed) the EP process group; None means 'use world'."""
        if not dist.is_available() or not dist.is_initialized():
            LOG.warning(
                "AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
            )
            return None
        ep_size = getattr(cfg, "expert_parallel_size", None)
        if not ep_size or ep_size <= 1:
            LOG.warning(
                "AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
            )
            return None
        world = dist.get_world_size()
        if world % ep_size != 0:
            LOG.warning(
                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
                ep_size,
                world,
            )
            return None
        if ep_size == world:
            return dist.group.WORLD
        rank = dist.get_rank()
        # All ranks must collectively create all EP subgroups in the same order
        # to avoid deadlocks (dist.new_group is a collective operation).
        world_size = world
        my_group = None
        for group_start in range(0, world_size, ep_size):
            ranks = tuple(range(group_start, group_start + ep_size))
            if ranks not in self._ep_group_cache:
                self._ep_group_cache[ranks] = dist.new_group(ranks)
            if rank in ranks:
                my_group = self._ep_group_cache[ranks]
        return my_group

    def add_callbacks_post_trainer(self, cfg, trainer):
        """Register the post-step bias-update callback when aux-free is active."""
        if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
            return []
        if self._shim is None:
            # post_model_build never activated (wrong model family or disabled).
            return []
        # gather concrete layer modules from handles
        layers = [h.layer for h in self._handles]
        cb = MoeAuxFreeBiasUpdateCallback(
            self._shim,
            layers,
            trainer,
        )
        LOG.info("AuxFreeMoE: registering post-step bias update callback")
        return [cb]

View File

@@ -240,7 +240,16 @@ def _softmax_topk_route(
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# Aux-free bias: biased selection, unbiased weights
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = routing_weights + afb_bias
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
routing_weights = routing_weights.gather(1, selected_experts)
_accumulate_afb_counts(moe_block, selected_experts)
else:
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
else:
scores_for_choice = router_probs
# Aux-free bias: stacks on top of e_score_correction_bias for selection
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = scores_for_choice + afb_bias
# Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1)
if n_group > 1:
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
)
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
"""Accumulate per-expert token counts for aux-free bias updates."""
afb_counts = getattr(moe_block, "_afb_counts", None)
if afb_counts is None:
return
flat_idx = topk_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
afb_counts.add_(counts.to(afb_counts.dtype))
# =============================================================================
# Shared expert helpers
# =============================================================================

View File

@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
as buffers. The routing functions transparently inject the bias into expert
*selection* (biased topk) while keeping mixture *weights* from unbiased
scores, then accumulate per-expert token counts for the post-step bias update.
"""
import torch
@@ -101,17 +107,25 @@ def softmax_topk_routing(
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Aux-free bias: biased selection, unbiased weights
afb_bias = getattr(moe_block, "_afb_bias", None)
scores_for_choice = router_probs
if afb_bias is not None:
scores_for_choice = router_probs + afb_bias
# Select top-k experts per token
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
# When aux-free bias is active, gather unbiased weights and accumulate counts
if afb_bias is not None:
top_values = router_probs.gather(1, top_indices)
_accumulate_afb_counts(moe_block, top_indices)
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
# top_values = top_values.to(router_probs.dtype)
# Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout:
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Aux-free bias: inject before group selection / topk
afb_bias = getattr(moe_block, "_afb_bias", None)
scores_for_choice = router_probs
if afb_bias is not None:
scores_for_choice = router_probs + afb_bias
# Group selection: pick top groups, mask the rest
if n_group > 1:
@@ -159,11 +177,17 @@ def softmax_group_topk_routing(
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
scores_for_choice = scores_for_choice.masked_fill(
~score_mask.bool(), -float("inf")
)
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
@@ -233,6 +257,11 @@ def sigmoid_topk_routing(
)
scores_for_choice = router_probs + e_score_correction_bias
# Aux-free bias: stacks on top of e_score_correction_bias for selection
afb_bias = getattr(moe_block, "_afb_bias", None)
if afb_bias is not None:
scores_for_choice = scores_for_choice + afb_bias
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
if n_group > 1:
group_scores = (
@@ -248,7 +277,9 @@ def sigmoid_topk_routing(
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
scores_for_choice = scores_for_choice.masked_fill(
~score_mask.bool(), -float("inf")
)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
@@ -256,6 +287,10 @@ def sigmoid_topk_routing(
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Accumulate counts for aux-free bias update
if afb_bias is not None:
_accumulate_afb_counts(moe_block, topk_indices)
# Optional renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
@@ -276,3 +311,21 @@ def sigmoid_topk_routing(
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
"""Accumulate per-expert token counts for the aux-free bias update.
Called when ``moe_block._afb_bias`` is present (registered by the
``aux_free_router`` plugin). The counts are later consumed by the
``MoeAuxFreeBiasUpdateCallback`` at each training step.
"""
if hasattr(moe_block, "training") and not moe_block.training:
return
afb_counts = getattr(moe_block, "_afb_counts", None)
if afb_counts is None:
return
num_experts = afb_counts.numel()
flat_idx = topk_indices.reshape(-1)
counts = torch.bincount(flat_idx, minlength=num_experts)
afb_counts.add_(counts.to(afb_counts.dtype))

View File

@@ -299,6 +299,7 @@ def validate_config(
AxolotlInputConfig = AxolotlInputConfigBase
if cfg.plugins:
prepare_plugins(cfg)
(
AxolotlConfigWCapabilities,
AxolotlInputConfig,

View File

@@ -836,6 +836,12 @@ class AxolotlInputConfig(
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
},
)
expert_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
},
)
special_tokens: SpecialTokensConfig | None = Field(
default=None,
json_schema_extra={

View File

@@ -1386,6 +1386,14 @@ class ComplexValidationMixin:
self.tensor_parallel_size = 1
return self
@model_validator(mode="after")
def check_expert_parallel_size(self):
if not getattr(self, "expert_parallel_size", None):
self.expert_parallel_size = 1
elif self.expert_parallel_size < 1:
raise ValueError("expert_parallel_size must be >= 1")
return self
@model_validator(mode="after")
def check_context_parallel_size(self):
if self.sequence_parallel_degree and not self.context_parallel_size:

View File

@@ -0,0 +1,75 @@
"""
E2E smoke test for Llama 4 aux-loss-free routing via plugin
"""
import unittest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestLlama4MoeAuxFree(unittest.TestCase):
    """Smoke test to ensure aux-free plugin patches Llama 4 MoE correctly."""

    @with_temp_dir
    def test_llama4_aux_free_smoke(self, temp_dir):
        """Train a tiny random Llama 4 for 5 steps and verify the plugin
        attached its 1-D bias/count buffers to at least one MoE module."""
        cfg = DictDefault(
            {
                "base_model": "yujiepan/llama-4-tiny-random",
                "tokenizer_config": "yujiepan/llama-4-tiny-random",
                "trust_remote_code": False,
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and its tuning knobs
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )
        # Same order as production: plugins first, then validate/normalize.
        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)
        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
        # A patched layer carries the plugin-registered buffers.
        patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
        assert patched is not None, (
            "Llama 4 MoE layer was not patched by aux-free plugin"
        )
        assert patched._afb_bias.ndim == 1
        assert patched._afb_counts.ndim == 1
        check_model_output_exists(temp_dir, cfg)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,79 @@
"""
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestMoeAuxFree(unittest.TestCase):
    """Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""

    @with_temp_dir
    def test_mixtral_aux_free_smoke(self, temp_dir):
        """Train 5 steps with the plugin enabled, then check that a layer was
        patched, its buffers are 1-D, and the counts were drained post-step."""
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and toggles
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )
        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)
        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
        # Inspect model modules for a patched MoE layer
        patched = None
        for m in model.modules():
            if hasattr(m, "_afb_patched") and m._afb_patched is True:
                patched = m
                break
        assert patched is not None, "No MoE layer patched by aux-free plugin"
        assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
        assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
        # ensure counts buffer got reset by callback (best effort)
        assert torch.all(patched._afb_counts == 0)
        check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,91 @@
"""
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""
import gc
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
def _last_logged_loss(trainer) -> float | None:
# Scan log_history for the most recent entry with a 'loss' key
for entry in reversed(trainer.state.log_history):
if isinstance(entry, dict) and "loss" in entry:
return float(entry["loss"])
return None
class TestMoeAuxParity(unittest.TestCase):
    """Parity check: aux-free (noaux_tc) training loss must stay within 10%
    of the aux-loss (gshard) baseline on Mixtral-tiny."""

    @with_temp_dir
    def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
        # Shared settings for both runs; only the routing/plugin bits differ.
        base_cfg = {
            "base_model": "hf-internal-testing/Mixtral-tiny",
            "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
            "flash_attention": False,
            "sequence_len": 512,
            "bf16": False,
            "fp16": False,
            "val_set_size": 0.02,
            "special_tokens": {},
            "datasets": [
                {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
            ],
            "num_epochs": 1,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "learning_rate": 1e-5,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
            "max_steps": 8,
            "save_steps": 0,
            "eval_steps": 0,
            "save_first_step": False,
            "seed": 42,
            "logging_steps": 1,
        }
        # Baseline: aux-loss (gshard)
        cfg0 = DictDefault(dict(base_cfg))
        cfg0.output_dir = f"{temp_dir}/baseline"
        cfg0 = validate_config(cfg0)
        normalize_config(cfg0)
        # baseline uses default aux-loss routing; no plugin registration
        dataset_meta0 = load_datasets(cfg=cfg0)
        model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
        loss0 = _last_logged_loss(trainer0)
        assert loss0 is not None
        # Release baseline resources before starting aux-free run
        del model0, trainer0, dataset_meta0
        gc.collect()
        torch.cuda.empty_cache()
        # Aux-free: plugin + noaux_tc
        cfg1 = DictDefault(dict(base_cfg))
        cfg1.output_dir = f"{temp_dir}/auxfree"
        cfg1.plugins = [
            "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
        ]
        cfg1.moe_balance_type = "noaux_tc"
        cfg1.moe_update_rate = 0.01
        cfg1.moe_update_momentum = 0.9
        cfg1.moe_bias_cap = 2.0
        prepare_plugins(cfg1)
        cfg1 = validate_config(cfg1)
        normalize_config(cfg1)
        dataset_meta1 = load_datasets(cfg=cfg1)
        model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
        loss1 = _last_logged_loss(trainer1)
        assert loss1 is not None
        # Assert aux-free loss is within 10% of aux-loss baseline
        assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"

View File

@@ -0,0 +1,76 @@
"""
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
"""
import unittest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestQwen3MoeAuxFree(unittest.TestCase):
    """Smoke test for aux-free routing on a tiny Qwen3-MoE checkpoint."""

    @with_temp_dir
    def test_qwen3_moe_aux_free_smoke(self, temp_dir):
        """Train 5 steps and verify a SparseMoeBlock was patched with the
        plugin's 1-D bias/count buffers."""
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
                "tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and toggles
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )
        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)
        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
        # check that at least one sparse MoE block has been patched
        found = False
        for m in model.modules():
            if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
                m, "_afb_patched"
            ):
                assert m._afb_patched is True
                assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
                assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
                found = True
                break
        assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
        check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,74 @@
"""
E2E smoke test for Ring 2.0 aux-loss-free routing via plugin
"""
import unittest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestRingMoeAuxFree(unittest.TestCase):
    """Smoke test to ensure aux-free plugin patches Ring Mini 2.0 correctly."""

    @with_temp_dir
    def test_ring_aux_free_smoke(self, temp_dir):
        """Train a tiny Ring checkpoint (remote code) for 5 steps and verify
        the plugin attached its 1-D bias/count buffers to a MoE layer."""
        cfg = DictDefault(
            {
                "base_model": "yujiepan/ring-tiny-random",
                "tokenizer_config": "yujiepan/ring-tiny-random",
                # Ring ships custom modeling code on the Hub.
                "trust_remote_code": True,
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin config
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )
        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)
        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
        patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
        assert patched is not None, "Ring MoE layer was not patched by aux-free plugin"
        assert patched._afb_bias.ndim == 1
        assert patched._afb_counts.ndim == 1
        check_model_output_exists(temp_dir, cfg)
if __name__ == "__main__":
unittest.main()

View File

@@ -12,7 +12,11 @@ from pathlib import Path
import torch
from packaging import version
from tbparse import SummaryReader
try:
from tbparse import SummaryReader
except ImportError: # pragma: no cover - optional dependency
SummaryReader = None
from axolotl.utils.dict import DictDefault
@@ -185,6 +189,10 @@ def check_tensorboard(
"""
helper function to parse and check tensorboard logs
"""
if SummaryReader is None:
raise unittest.SkipTest(
"tbparse is not installed; skipping tensorboard assertions"
)
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)

View File

@@ -0,0 +1,666 @@
import os
import sys
import tempfile
import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace
import torch
import torch.distributed as dist
import torch.nn as nn
from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
def _cfg(**overrides):
defaults = dict(
moe_balance_type="noaux_tc",
moe_update_rate=0.1,
moe_update_momentum=0.9,
moe_bias_cap=2.0,
moe_afb_warmup_steps=0,
moe_bias_sync_group="world",
expert_parallel_size=1,
)
defaults.update(overrides)
return SimpleNamespace(**defaults)
def _load_bailing_modules():
    """Download and import the Ring/Bailing remote-code modules from the Hub.

    Returns the ``(BailingMoeV2Config, BailingMoeV2SparseMoeBlock)`` pair.
    Each module is registered in ``sys.modules`` under both a package-
    qualified name and its bare filename BEFORE exec_module runs, so any
    by-name imports performed during execution resolve; the sys.modules
    check also makes repeated calls reuse the already-executed modules.
    """
    repo_dir = snapshot_download(
        repo_id="inclusionAI/Ring-mini-2.0",
        allow_patterns=[
            "configuration_bailing_moe_v2.py",
            "modeling_bailing_moe_v2.py",
            "__init__.py",
        ],
    )
    repo = Path(repo_dir)
    config_path = repo / "configuration_bailing_moe_v2.py"
    modeling_path = repo / "modeling_bailing_moe_v2.py"
    config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
    if config_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(config_name, config_path)
        module = importlib_util.module_from_spec(spec)
        # Register under both names before exec so imports resolve mid-exec.
        sys.modules[config_name] = module
        sys.modules["configuration_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    config_module = sys.modules[config_name]
    modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
    if modeling_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[modeling_name] = module
        sys.modules["modeling_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    modeling_module = sys.modules[modeling_name]
    BailingMoeV2Config = config_module.BailingMoeV2Config
    BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
    return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
def _build_bailing_model():
    """Construct a tiny Bailing/Ring sparse-MoE block wrapped in a dummy model.

    Returns ``(model, block)``; the wrapper exposes
    ``config.model_type == "bailing_moe"``, which the plugin's adapter
    matching reads.
    """
    BailingConfig, BailingBlock = _load_bailing_modules()
    # Deliberately tiny dims so the forward pass is fast on CPU.
    config = BailingConfig(
        hidden_size=16,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_experts=4,
        num_shared_experts=None,
        num_experts_per_tok=2,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
    )
    block = BailingBlock(config)

    class DummyModel(nn.Module):
        """Minimal nn.Module wrapper exposing the config the plugin inspects."""

        def __init__(self, layer):
            super().__init__()
            self.block = layer
            self.config = SimpleNamespace(model_type="bailing_moe")

        def forward(self, hidden_states):
            return self.block(hidden_states)

    return DummyModel(block), block
def _build_llama4_model():
    """Construct a tiny Llama4TextMoe layer wrapped in a dummy model.

    Returns ``(model, layer)``; the wrapper exposes
    ``config.model_type == "llama4"`` for adapter matching.
    """
    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

    # Build config without __post_init__ validation (works around a
    # huggingface_hub strict-dataclass type mismatch for layer_types).
    config = object.__new__(__import__("transformers").Llama4TextConfig)
    config.__dict__.update(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_experts_per_tok=2,
        num_hidden_layers=2,
        hidden_act="silu",
        layer_types=None,
    )
    layer = Llama4TextMoe(config)

    class DummyModel(nn.Module):
        """Minimal wrapper exposing config.model_type for adapter matching."""

        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="llama4")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer
def _build_mixtral_model():
    """Construct a tiny MixtralSparseMoeBlock wrapped in a dummy model.

    Returns ``(model, layer)``.  The config is also attached to the block
    itself because the tests read ``layer.config.hidden_size`` directly.
    """
    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    config = MixtralConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    layer = MixtralSparseMoeBlock(config)
    # Expose the config on the block for convenience in the tests.
    layer.config = config

    class DummyModel(nn.Module):
        """Minimal wrapper exposing config.model_type for adapter matching."""

        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="mixtral")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer
def _build_qwen35_moe_model():
    """Construct a tiny Qwen 3.5 sparse-MoE block wrapped in a dummy model.

    Returns ``(model, layer)``; the wrapper exposes
    ``config.model_type == "qwen3_5_moe"`` for adapter matching.
    """
    from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
        Qwen3_5MoeTextConfig,
    )
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
        Qwen3_5MoeSparseMoeBlock,
    )

    # Includes a shared expert path (shared_expert_intermediate_size).
    config = Qwen3_5MoeTextConfig(
        hidden_size=16,
        moe_intermediate_size=32,
        shared_expert_intermediate_size=32,
        num_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_hidden_layers=2,
    )
    layer = Qwen3_5MoeSparseMoeBlock(config)

    class DummyModel(nn.Module):
        """Minimal wrapper exposing config.model_type for adapter matching."""

        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="qwen3_5_moe")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
    """Drive one ``on_step_end`` through the plugin's registered callback.

    Builds throwaway Trainer-like args/state/control objects when they are
    not supplied, and returns the ``(state, control)`` pair for inspection.
    """
    args = args if args is not None else SimpleNamespace(logging_steps=1)
    state = state if state is not None else SimpleNamespace(
        global_step=1, log_history=[]
    )
    if control is None:
        control = SimpleNamespace(
            should_log=False,
            should_evaluate=False,
            should_save=False,
            should_training_stop=False,
        )

    class _RecordingTrainer:
        """Mimics the subset of Trainer the callback touches: state,
        control, and a log() that appends to log_history."""

        def __init__(self, state_obj, control_obj):
            self.state = state_obj
            self.control = control_obj

        def log(self, logs):
            entry = dict(logs)
            entry["step"] = self.state.global_step
            self.state.log_history.append(entry)
            self.control.should_log = True

    recording_trainer = _RecordingTrainer(state, control)
    callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=recording_trainer)
    assert callbacks, "expected aux-free callback to be registered"
    callbacks[0].on_step_end(args=args, state=state, control=control)
    return state, control
class TestAuxFreeAdapters(unittest.TestCase):
    """Unit coverage for the per-architecture aux-free adapters and the full
    plugin lifecycle: patch at model build, accumulate counts during forward,
    drain counts / update EMA+bias in the post-step callback."""

    def test_bailing_adapter_updates_counts_and_bias(self):
        """Bailing/Ring: forward accumulates counts; the callback drains
        them and moves the EMA off zero."""
        model, block = _build_bailing_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        self.assertTrue(hasattr(block, "_afb_bias"))
        hidden = torch.randn(2, 3, block.config.hidden_size)
        block(hidden)
        self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
        _run_callback(plugin, cfg)
        # Callback resets counts and folds them into the EMA.
        self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
        self.assertFalse(
            torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
        )

    def test_llama4_adapter_biases_router_selection(self):
        """Llama 4: same accumulate/drain cycle through its adapter."""
        model, layer = _build_llama4_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        self.assertTrue(hasattr(layer, "_afb_bias"))
        hidden = torch.randn(2, 4, layer.hidden_dim)
        layer(hidden)
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_bias_warmup_respected(self):
        """With moe_afb_warmup_steps=2, the first two steps leave the bias
        at zero; the third step updates it."""
        model, block = _build_bailing_model()
        cfg = _cfg(moe_afb_warmup_steps=2)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        def _step():
            # One forward + one post-step callback = one "training step".
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            _run_callback(plugin, cfg)

        # Warmup steps should leave bias untouched.
        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )
        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )
        # Third step exceeds warmup -> bias should update.
        _step()
        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)

    def test_mixtral_adapter_patches_router_not_forward(self):
        """Verify that aux-free patches the router (gate) only, and the
        v5 block forward signature (single tensor return) is preserved."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        # Gate should be patched, not the block forward
        self.assertTrue(getattr(layer.gate, "_afb_patched", False))
        self.assertTrue(getattr(layer, "_afb_patched", False))
        # v5 block forward returns a single tensor (not a tuple with logits)
        hidden = torch.randn(2, 3, layer.config.hidden_size)
        out = layer(hidden)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape, hidden.shape)
        # Counts should have been accumulated
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)

    def test_mixtral_adapter_bias_affects_selection(self):
        """When bias is large for one expert, it should be selected more often."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        # Set a large bias for expert 0 to force its selection
        layer._afb_bias.zero_()
        layer._afb_bias[0] = 10.0
        hidden = torch.randn(2, 8, layer.config.hidden_size)
        num_tokens = 2 * 8  # batch * seq
        layer(hidden)
        # With top_k=2, expert 0 should appear in every token's selection
        # (once per token = num_tokens counts, not num_tokens * top_k)
        counts = layer._afb_counts.clone()
        self.assertEqual(
            int(counts[0].item()),
            num_tokens,
            msg="Expert 0 should be selected for every token when heavily biased",
        )

    def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
        """Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
        output includes shared expert contribution."""
        model, layer = _build_qwen35_moe_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        # Gate should be patched
        self.assertTrue(getattr(layer.gate, "_afb_patched", False))
        self.assertTrue(getattr(layer, "_afb_patched", False))
        # Shared expert should be unmodified
        self.assertTrue(hasattr(layer, "shared_expert"))
        self.assertTrue(hasattr(layer, "shared_expert_gate"))
        # Forward should return a single tensor (shared + routed)
        hidden_size = layer.gate.hidden_dim
        hidden = torch.randn(2, 3, hidden_size)
        out = layer(hidden)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape, hidden.shape)
        # Counts should have been accumulated
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

    def test_qwen35_moe_adapter_bias_updates(self):
        """Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
        model, layer = _build_qwen35_moe_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        hidden_size = layer.gate.hidden_dim
        hidden = torch.randn(2, 4, hidden_size)
        layer(hidden)
        # Bias should start at zero
        self.assertTrue(
            torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
        )
        _run_callback(plugin, cfg)
        # After callback: counts reset, EMA updated, bias updated
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_qwen35_moe_adapter_model_type_matching(self):
        """Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
        from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter

        adapter = Qwen35MoeAdapter()
        model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
        model_text = SimpleNamespace(
            config=SimpleNamespace(model_type="qwen3_5_moe_text")
        )
        model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
        self.assertTrue(adapter.matches(model_moe))
        self.assertTrue(adapter.matches(model_text))
        self.assertFalse(adapter.matches(model_other))

    def test_ep_group_resolution_deferred_until_dist_ready(self):
        """The shim's ep_group stays None at build time (dist not up); once a
        process group exists, the callback resolves it (world group here)."""
        if dist.is_available() and dist.is_initialized():
            self.skipTest(
                "Cannot safely test deferred EP group resolution when a process group is already initialized"
            )
        model, block = _build_bailing_model()
        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        self.assertIsNotNone(plugin._shim)
        self.assertIsNone(plugin._shim.ep_group)
        # Single-rank gloo group via a file:// rendezvous.
        tmp_init = tempfile.NamedTemporaryFile(delete=False)
        tmp_init.close()
        init_method = f"file://{tmp_init.name}"
        dist.init_process_group(
            backend="gloo", init_method=init_method, world_size=1, rank=0
        )
        try:
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            _run_callback(
                plugin,
                cfg,
                args=SimpleNamespace(logging_steps=1),
                state=SimpleNamespace(global_step=1, log_history=[]),
                control=SimpleNamespace(
                    should_log=False,
                    should_evaluate=False,
                    should_save=False,
                    should_training_stop=False,
                ),
            )
            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
        finally:
            dist.destroy_process_group()
            os.unlink(tmp_init.name)

    def test_telemetry_logging(self):
        """Callback should push moe_afb/* metrics through trainer.log and
        flip control.should_log."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        hidden_dim = layer.config.hidden_size
        hidden = torch.randn(2, 3, hidden_dim)
        layer(hidden)
        args = SimpleNamespace(logging_steps=1)
        state = SimpleNamespace(global_step=1, log_history=[])
        control = SimpleNamespace(
            should_log=False,
            should_evaluate=False,
            should_save=False,
            should_training_stop=False,
        )
        _run_callback(plugin, cfg, args=args, state=state, control=control)
        self.assertTrue(control.should_log)
        self.assertTrue(state.log_history)
        telemetry = state.log_history[-1]
        self.assertEqual(telemetry["step"], state.global_step)
        self.assertIn("moe_afb/l0_load_min", telemetry)
        self.assertIn("moe_afb/l0_load_max", telemetry)
        self.assertIn("moe_afb/l0_bias_abs_max", telemetry)

    def test_get_num_experts_v5_attribute_paths(self):
        """Verify get_num_experts works with v5 attribute layout where
        num_experts is on gate/experts sub-modules, not the block."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        adapter = MixtralAdapter()
        # Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
        block = SimpleNamespace(
            gate=SimpleNamespace(num_experts=8),
            experts=SimpleNamespace(num_experts=8),
        )
        self.assertEqual(adapter.get_num_experts(block), 8)
        # Also works when num_experts is directly on block
        block2 = SimpleNamespace(num_experts=4)
        self.assertEqual(adapter.get_num_experts(block2), 4)
class TestAuxFreeKernelComposition(unittest.TestCase):
    """Tests that aux-free bias composes correctly with kernel routing."""

    @staticmethod
    def _softmax_gate(hidden_dim, num_experts, top_k):
        # Build a router gate carrying the attributes kernel routing inspects.
        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True
        return gate

    def test_sonicmoe_softmax_routing_with_afb_bias(self):
        """SonicMoE softmax routing should use biased selection / unbiased weights."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts, top_k, hidden_dim, T = 4, 2, 16, 6
        moe_block = SimpleNamespace(
            gate=self._softmax_gate(hidden_dim, num_experts, top_k)
        )
        hidden = torch.randn(T, hidden_dim)
        # Baseline run: no aux-free buffers attached yet.
        scores_base, _, _, _ = softmax_topk_routing(hidden, moe_block)
        self.assertEqual(scores_base.shape[0], T * top_k)
        # Attach aux-free buffers with an overwhelming bias toward expert 0.
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)
        _, _, exp_biased, _ = softmax_topk_routing(hidden, moe_block)
        # Expert 0 must show up in the biased selections.
        self.assertTrue(
            (exp_biased == 0).any(),
            "Expert 0 should appear in selections when heavily biased",
        )
        # Selection counters accumulate and sum to tokens * top_k.
        self.assertGreater(moe_block._afb_counts[0].item(), 0)
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_sonicmoe_routing_without_bias_unchanged(self):
        """Without _afb_bias, routing should produce identical results."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts, top_k, hidden_dim = 4, 2, 16
        moe_block = SimpleNamespace(
            gate=self._softmax_gate(hidden_dim, num_experts, top_k)
        )
        hidden = torch.randn(4, hidden_dim)
        # First pass: no aux-free attributes present at all.
        scores_plain, _, experts_plain, _ = softmax_topk_routing(hidden, moe_block)
        # Second pass: an all-zero bias must be a routing no-op.
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_counts = torch.zeros(num_experts)
        scores_zero, _, experts_zero, _ = softmax_topk_routing(hidden, moe_block)
        torch.testing.assert_close(scores_plain, scores_zero)
        torch.testing.assert_close(experts_plain, experts_zero)

    @unittest.skipUnless(
        importlib_util.find_spec("triton") is not None,
        "triton not installed (required by scattermoe)",
    )
    def test_scattermoe_softmax_routing_with_afb_bias(self):
        """ScatterMoE softmax routing should use biased selection / unbiased weights."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )

        num_experts, top_k, hidden_dim, T = 4, 2, 16, 6
        gate_weight = torch.randn(num_experts, hidden_dim)
        base_gate = SimpleNamespace(
            top_k=top_k,
            num_experts=num_experts,
            norm_topk_prob=True,
            weight=gate_weight,
        )
        moe_block = SimpleNamespace()
        hidden = torch.randn(T, hidden_dim)
        # Baseline routing without any aux-free state on the block.
        first_weights, first_experts, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )
        # Heavy bias on expert 0 steers subsequent selections.
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)
        biased_weights, biased_experts, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )
        # Expert 0 is selected somewhere, and counts track tokens * top_k.
        self.assertTrue((biased_experts == 0).any())
        self.assertGreater(moe_block._afb_counts[0].item(), 0)
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_kernel_routing_skips_router_patch(self):
        """When a kernel backend has patched the block class, the adapter
        should skip patching the router (buffers are still registered)."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        class KernelOwnedBlock(nn.Module):
            # _original_forward on the class is SonicMoE's patch marker.
            _original_forward = True

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)  # placeholder

        adapter = MixtralAdapter()
        layer = KernelOwnedBlock()
        self.assertTrue(adapter.uses_kernel_routing(layer))
        # The gate must remain unpatched; the kernel owns the routing path.
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))

    def test_adapter_buffers_registered_even_with_kernel(self):
        """Even when kernel routing is active, aux-free buffers must be
        registered on the MoE block so the kernel routing can find them."""
        from axolotl.integrations.aux_free_router.adapters import (
            LayerHandle,
            MixtralAdapter,
        )
        from axolotl.integrations.aux_free_router.core import (
            AuxFreeConfig,
            AuxFreeShim,
            AuxFreeState,
        )

        class KernelOwnedBlock(nn.Module):
            _original_forward = True

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)

        layer = KernelOwnedBlock()
        shim = AuxFreeShim(
            state=AuxFreeState(
                num_layers=1,
                num_experts=4,
                device=torch.device("cpu"),
                cfg=AuxFreeConfig(),
            )
        )
        MixtralAdapter().prepare(
            layer,
            LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2),
            shim,
        )
        # Buffers must exist so kernel-side routing can consume them...
        for buffer_name in ("_afb_bias", "_afb_counts", "_afb_ema"):
            self.assertTrue(hasattr(layer, buffer_name))
        # ...while the gate itself stays unpatched.
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))
# Allow running this test module directly (e.g. `python <this_file>.py`).
if __name__ == "__main__":
    unittest.main()