feat(moe-aux-loss-free): aux-free MoE plugin (Mixtral/Qwen3), EMA bias updates, config keys; E2E smoke + parity tests

Author: lhl
Date: 2025-10-27 00:14:38 +09:00
Committed by: Wing Lian
Parent: 5b2e3f00ce
Commit: 3e4688289c
9 changed files with 697 additions and 0 deletions


@@ -0,0 +1,38 @@
# Aux-Loss-Free MoE Router Plugin

This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.

## Summary

- Bias only affects expert selection (top-k); mixture weights come from the unbiased logits (see the sketch below).
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
- Bias is updated after each optimizer step, outside autograd, using EMA-smoothed loads.
- The existing aux loss is disabled when aux-free is enabled, to avoid sending the router two balancing signals.
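
A minimal sketch of the gating rule, mirroring the adapter code in this commit (the function name is illustrative):

```python
import torch

def afb_select(logits: torch.Tensor, bias: torch.Tensor, top_k: int):
    """Aux-free gating: the bias steers which experts are picked,
    while mixture weights come from the unbiased logits."""
    # biased logits decide *which* experts win the top-k ...
    _, topk_idx = torch.topk(logits + bias, k=top_k, dim=-1)
    # ... but the weights are softmaxed from the unbiased logits,
    # so the bias never enters the gradient path
    chosen = torch.gather(logits, -1, topk_idx)
    weights = torch.softmax(chosen.float(), dim=-1)
    return topk_idx, weights
```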
## Enable

Add the plugin to your YAML, then set the aux-free toggle:

```yaml
plugins:
  - axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin

moe_balance_type: noaux_tc
moe_update_rate: 0.01       # default if unset
moe_update_momentum: 0.9    # default if unset
moe_bias_cap: 2.0           # default if unset
moe_afb_warmup_steps: 100   # optional
moe_bias_sync_group: world  # or 'ep' if expert parallelism is configured
```
## Config keys

- `moe_balance_type`: `gshard` (auxiliary loss) or `noaux_tc` (aux-free). Default: the model's native behavior.
- `moe_update_rate`: bias update rate (gamma); see the update sketch below. Default: `0.01`.
- `moe_update_momentum`: EMA momentum for load smoothing. Default: `0.9`.
- `moe_bias_cap`: absolute clamp on bias magnitude. Default: `2.0`.
- `moe_afb_warmup_steps`: number of steps to delay bias updates. Default: `0`.
- `moe_bias_sync_group`: reduction group for expert load counts, `world` (DP) or `ep` (expert-parallel). Default: `world`.
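
How these keys combine in the per-step update (a sketch of the rule implemented in the core module; `freq` is each expert's share of routed tokens this step):

```python
# Sketch only; all names are per-expert vectors except the scalar hyperparameters.
ema = momentum * ema + (1.0 - momentum) * freq       # moe_update_momentum
delta = rate * (1.0 / num_experts - ema)             # moe_update_rate (gamma)
delta = delta - delta.mean()                         # mean-centering keeps sum(bias) ~ 0
bias = (bias + delta).clamp(-bias_cap, bias_cap)     # moe_bias_cap
```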
## Compatibility

- Targeted families: Mixtral, Qwen3-MoE. Jamba support is optional.
- Pass-through: models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; at most, telemetry may be added in the future.

## Notes

- If you also enable Liger's aux-loss paths, the plugin neutralizes the aux loss whenever aux-free is on.
- Telemetry: future updates will log per-expert loads and bias magnitudes.


@@ -0,0 +1,2 @@
"""Aux-loss-free (AFB) MoE router integration package."""


@@ -0,0 +1,172 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional

import torch
from torch import nn
import torch.nn.functional as F

from axolotl.utils.logging import get_logger

from .core import AuxFreeShim

LOG = get_logger(__name__)

@dataclass
class LayerHandle:
    layer: nn.Module
    layer_idx: int
    num_experts: int
    top_k: int


class BaseMoEAdapter:
    """Base adapter that discovers MoE layers and wraps their forward.

    Concrete adapters should implement discovery and per-layer attribute extraction.
    """

    family: str = "generic"

    def matches(self, model: nn.Module) -> bool:  # pragma: no cover - thin shim
        return False

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:  # pragma: no cover
        return []

    def get_top_k(self, moe_layer: nn.Module) -> int:  # pragma: no cover
        return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))

    def get_num_experts(self, moe_layer: nn.Module) -> int:  # pragma: no cover
        return int(getattr(moe_layer, "num_experts"))

    def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
        # Best-effort: zero the router aux loss coefficient if present
        if hasattr(model_or_layer, "router_aux_loss_coef"):
            try:
                setattr(model_or_layer, "router_aux_loss_coef", 0.0)
            except Exception:  # pragma: no cover - non-critical
                pass

    def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
        """Attach per-layer buffers, record routing metadata, and patch the forward.

        Buffers are registered idempotently so repeated calls do not clobber state.
        """
        device = next(moe_layer.parameters(), torch.tensor(0)).device
        if not hasattr(moe_layer, "_afb_bias"):
            moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
        if not hasattr(moe_layer, "_afb_counts"):
            moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
        if not hasattr(moe_layer, "_afb_ema"):
            moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
        moe_layer._afb_layer_idx = handle.layer_idx  # type: ignore[attr-defined]
        moe_layer._afb_top_k = handle.top_k  # type: ignore[attr-defined]
        self._patch_forward_with_aux_free(moe_layer)
    def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
        """Replace the layer's forward with an aux-free gating version.

        Assumes the layer exposes the following attributes:
        - gate: linear router projecting hidden states to num_experts logits
        - num_experts: int
        - experts: indexable collection of expert modules mapping (tokens, H) -> (tokens, H)
        """
        if getattr(moe_layer, "_afb_patched", False):
            return
        if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
            LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
            return
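
        # The replacement forward below applies the aux-free rule: biased
        # logits (logits + _afb_bias) choose the top-k experts, mixture
        # weights are softmaxed from the unbiased logits (so the bias never
        # enters the gradient path), and per-expert token counts are
        # accumulated into _afb_counts for the post-step bias update.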
        def afb_forward(self, hidden_states: torch.Tensor):  # type: ignore[no-redef]
            # hidden_states: (B, T, H)
            bsz, seqlen, hdim = hidden_states.shape
            hs = hidden_states.view(-1, hdim)
            logits = self.gate(hs)
            # selection uses biased logits; weights come from unbiased logits
            bias = getattr(self, "_afb_bias")
            top_k = int(getattr(self, "_afb_top_k", 2))
            biased = logits + bias  # broadcast over tokens
            topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
            chosen_logits = torch.gather(logits, -1, topk_idx)
            weights = torch.softmax(chosen_logits.float(), dim=-1)
            weights = weights.to(hs.dtype)
            # accumulate counts for the bias update callback
            flat_idx = topk_idx.reshape(-1)
            counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
            getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
            # dispatch tokens to experts (token-major, so flat_idx lines up with hs_rep)
            hs_rep = hs.repeat_interleave(top_k, dim=0)
            y = torch.empty_like(hs_rep)
            for eid in range(int(self.num_experts)):
                mask = flat_idx == eid
                if mask.any():
                    y[mask] = self.experts[eid](hs_rep[mask])
            y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
            out = y.view(bsz, seqlen, hdim)
            return (out, logits)

        moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__)  # type: ignore[attr-defined]
        setattr(moe_layer, "_afb_patched", True)

class MixtralAdapter(BaseMoEAdapter):
    family = "mixtral"

    def matches(self, model: nn.Module) -> bool:
        return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        for m in model.modules():
            if m.__class__.__name__.endswith("SparseMoeBlock"):
                yield m


class Qwen3Adapter(MixtralAdapter):
    family = "qwen3_moe"

    def matches(self, model: nn.Module) -> bool:
        return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")

def discover_and_prepare_layers(
    model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
) -> list[LayerHandle]:
    """Discover MoE layers using the first matching adapter and attach per-layer buffers.

    Returns a list of layer handles for later routing patches and bias updates.
    """
    handles: list[LayerHandle] = []
    adapter: Optional[BaseMoEAdapter] = None
    for a in adapters:
        if a.matches(model):
            adapter = a
            break
    if adapter is None:
        LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
        return handles
    # disable the aux loss at the model level if possible
    adapter.disable_aux_loss(getattr(model, "config", model))
    idx = 0
    for layer in adapter.find_moe_layers(model):
        try:
            top_k = adapter.get_top_k(layer)
            n_experts = adapter.get_num_experts(layer)
        except Exception:
            continue
        handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=n_experts, top_k=top_k)
        adapter.prepare(layer, handle, shim)
        handles.append(handle)
        idx += 1
    LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
    return handles
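
For orientation, a minimal sketch of how these helpers are driven (this mirrors `AuxFreeMoEPlugin.post_model_build` below; `model` and `shim` are assumed to already exist):

```python
# Illustrative wiring only; the plugin performs these steps during model build.
adapters = [MixtralAdapter(), Qwen3Adapter()]
handles = discover_and_prepare_layers(model, adapters, shim)  # attaches _afb_* buffers, patches forward
layers = [h.layer for h in handles]  # later handed to the post-step bias-update callback
```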


@@ -0,0 +1,76 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.distributed as dist


@dataclass
class AuxFreeConfig:
    rate: float = 0.01
    momentum: float = 0.9
    bias_cap: float = 2.0
    warmup_steps: int = 0
    sync_group: str = "world"  # or "ep"

class AuxFreeState:
    """Holds per-layer bias and EMA load buffers."""

    def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
        self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
        self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
        self.cfg = cfg
        self.steps = 0


class AuxFreeShim:
    """Model-agnostic shim for aux-loss-free expert selection and bias updates."""

    def __init__(self, state: AuxFreeState, ep_group: Optional[dist.ProcessGroup] = None):
        self.state = state
        self.ep_group = ep_group
    @torch.no_grad()
    def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (topk_indices, weights) using biased selection and unbiased weights."""
        b = self.state.bias[layer_idx]
        biased = logits + b  # bias is a buffer, never a parameter
        topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
        chosen_logits = torch.gather(logits, -1, topk_idx)
        weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
        return topk_idx, weights

    @torch.no_grad()
    def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
        if not dist.is_available() or not dist.is_initialized():
            return counts
        group = self.ep_group if self.ep_group is not None else dist.group.WORLD
        dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
        return counts
    @torch.no_grad()
    def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
        """Apply an EMA-smoothed bias update toward the uniform target, with clamping and mean-centering."""
        cfg = self.state.cfg
        self.state.steps += 1
        if self.state.steps <= cfg.warmup_steps:
            return
        n_experts = step_counts.numel()
        if tokens_seen <= 0:
            return
        freq = step_counts.float() / float(tokens_seen)
        ema = self.state.ema_load[layer_idx]
        ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
        target = 1.0 / float(n_experts)
        delta = cfg.rate * (target - ema)
        # mean-centering keeps sum(bias) ~ 0 so the update only redistributes load
        delta = delta - delta.mean()
        bias = self.state.bias[layer_idx]
        bias.add_(delta)
        if cfg.bias_cap is not None and cfg.bias_cap > 0:
            bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
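
A self-contained sketch of one update cycle through the shim (single process, illustrative values; assumes the module lands at `axolotl.integrations.aux_free_router.core`):

```python
import torch

from axolotl.integrations.aux_free_router.core import AuxFreeConfig, AuxFreeShim, AuxFreeState

cfg = AuxFreeConfig(rate=0.01, momentum=0.9, bias_cap=2.0)
state = AuxFreeState(num_layers=1, num_experts=8, device=torch.device("cpu"), cfg=cfg)
shim = AuxFreeShim(state)

logits = torch.randn(16, 8)  # 16 tokens, 8 experts
topk_idx, weights = shim.select_experts(0, logits, top_k=2)  # biased pick, unbiased weights
counts = torch.bincount(topk_idx.reshape(-1), minlength=8)   # per-expert load this step
counts = shim.all_reduce_counts(counts)                      # no-op without torch.distributed
shim.update_bias(0, counts, tokens_seen=int(counts.sum()))   # EMA update, centered and clamped
```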


@@ -0,0 +1,133 @@
"""Aux-loss-free MoE Router Plugin for Axolotl.
This plugin wires an aux-free gating option into compatible MoE models using
unbiased logits for mixture weights and per-expert biases for top-k selection.
"""
from __future__ import annotations
from typing import Optional
import torch
from transformers.trainer_callback import TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_distributed
from axolotl.utils.logging import get_logger
from .adapters import BaseMoEAdapter, MixtralAdapter, Qwen3Adapter, discover_and_prepare_layers
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
LOG = get_logger(__name__)

class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
    """Post-step callback to update aux-free biases from accumulated expert counts.

    Each prepared MoE layer accumulates per-expert token counts into a buffer
    named `_afb_counts` during its patched forward; this callback reduces the
    counts, applies the EMA bias update, and resets the buffer.
    """
    def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
        self.shim = shim
        self.layer_modules = layer_modules

    def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
        # Iterate prepared MoE layers and apply the bias update rule.
        cfg = self.shim.state.cfg
        for layer in self.layer_modules:
            if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
                continue
            counts = getattr(layer, "_afb_counts")
            if counts is None:
                continue
            counts = self.shim.all_reduce_counts(counts)
            tokens_seen = int(counts.sum().item())
            # local layer-state EMA and bias update
            if tokens_seen > 0:
                freq = counts.float() / float(tokens_seen)
                ema = getattr(layer, "_afb_ema")
                ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
                n_experts = counts.numel()
                target = 1.0 / float(n_experts)
                delta = cfg.rate * (target - ema)
                delta = delta - delta.mean()
                bias = getattr(layer, "_afb_bias")
                bias.add_(delta)
                if cfg.bias_cap is not None and cfg.bias_cap > 0:
                    bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
            # reset step counts for the next accumulation window
            counts.zero_()
        return control

class AuxFreeMoEPlugin(BasePlugin):
    """Plugin that enables aux-loss-free routing when configured."""

    def __init__(self):
        super().__init__()
        self._handles: list = []
        self._shim: Optional[AuxFreeShim] = None

    def post_model_build(self, cfg, model):
        # Enable only when explicitly requested
        if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
            return
        # Be conservative: skip families with native aux-free routing
        native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
            "deepseek_v3",
            "glm4_moe",
        )
        if native_auxfree:
            LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
            return
        # Build the aux-free config, state, and shim
        rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
        momentum = cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
        bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
        warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
        sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
        af_cfg = AuxFreeConfig(
            rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
        )
        adapters: list[BaseMoEAdapter] = [MixtralAdapter(), Qwen3Adapter()]
        # Size the shared state conservatively; the per-layer buffers attached in
        # prepare() are what the update callback actually reads, so these values
        # only serve as placeholders.
        n_layers = sum(1 for _ in model.modules())  # upper bound; bias slots are used sparsely
        device = next(model.parameters(), torch.tensor(0)).device
        if n_layers <= 0:
            n_layers = 1
        n_experts = 2  # minimal placeholder; layers carry their own buffers
        state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
        self._shim = AuxFreeShim(state=state, ep_group=None)
        # Discover and prepare layers (attach per-layer buffers, patch forwards)
        self._handles = discover_and_prepare_layers(model, adapters, self._shim)
        LOG.info(
            f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
        )
    def add_callbacks_post_trainer(self, cfg, trainer):
        if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
            return []
        if self._shim is None:
            return []
        # gather concrete layer modules from the handles
        layers = [h.layer for h in self._handles]
        cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
        LOG.info("AuxFreeMoE: registering post-step bias update callback")
        return [cb]


@@ -758,6 +758,44 @@ class AxolotlInputConfig(
    llama4_linearized_experts: bool | None = None

    # MoE aux-loss-free (AFB) toggles
    moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to the model's native behavior when unset.",
        },
    )
    moe_update_rate: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Per-step bias update rate (gamma). Recommended range: 0.005-0.05. If unset, the plugin default is 0.01.",
        },
    )
    moe_update_momentum: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "EMA momentum for expert load smoothing (0-1). If unset, the plugin default is 0.9.",
        },
    )
    moe_bias_cap: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Absolute clamp for expert bias magnitude. If unset, the plugin default is 2.0.",
        },
    )
    moe_afb_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, the plugin default is 0.",
        },
    )
    moe_bias_sync_group: Literal["world", "ep"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
        },
    )
    deepspeed: str | dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={

@@ -0,0 +1,79 @@
"""
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestMoeAuxFree(unittest.TestCase):
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
@with_temp_dir
def test_mixtral_aux_free_smoke(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 512,
"bf16": False,
"fp16": False,
"val_set_size": 0.02,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 0,
"eval_steps": 0,
"save_first_step": False,
# Aux-free plugin and toggles
"plugins": [
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
],
"moe_balance_type": "noaux_tc",
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
prepare_plugins(cfg)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
# Inspect model modules for a patched MoE layer
patched = None
for m in model.modules():
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
patched = m
break
assert patched is not None, "No MoE layer patched by aux-free plugin"
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
# ensure counts buffer got reset by callback (best effort)
assert torch.all(patched._afb_counts == 0)
check_model_output_exists(temp_dir, cfg)


@@ -0,0 +1,83 @@
"""
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""
import unittest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
def _last_logged_loss(trainer) -> float | None:
# Scan log_history for the most recent entry with a 'loss' key
for entry in reversed(trainer.state.log_history):
if isinstance(entry, dict) and "loss" in entry:
return float(entry["loss"])
return None
class TestMoeAuxParity(unittest.TestCase):
@with_temp_dir
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
base_cfg = {
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 512,
"bf16": False,
"fp16": False,
"val_set_size": 0.02,
"special_tokens": {},
"datasets": [
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 8,
"save_steps": 0,
"eval_steps": 0,
"save_first_step": False,
"seed": 42,
"logging_steps": 1,
}
# Baseline: aux-loss (gshard)
cfg0 = DictDefault(dict(base_cfg))
cfg0.output_dir = f"{temp_dir}/baseline"
cfg0 = validate_config(cfg0)
normalize_config(cfg0)
# baseline uses default aux-loss routing; no plugin registration
dataset_meta0 = load_datasets(cfg=cfg0)
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
loss0 = _last_logged_loss(trainer0)
assert loss0 is not None
# Aux-free: plugin + noaux_tc
cfg1 = DictDefault(dict(base_cfg))
cfg1.output_dir = f"{temp_dir}/auxfree"
cfg1.plugins = [
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
]
cfg1.moe_balance_type = "noaux_tc"
cfg1.moe_update_rate = 0.01
cfg1.moe_update_momentum = 0.9
cfg1.moe_bias_cap = 2.0
cfg1 = validate_config(cfg1)
normalize_config(cfg1)
prepare_plugins(cfg1)
dataset_meta1 = load_datasets(cfg=cfg1)
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
loss1 = _last_logged_loss(trainer1)
assert loss1 is not None
# Assert aux-free loss is within 10% of aux-loss baseline
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"


@@ -0,0 +1,76 @@
"""
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestQwen3MoeAuxFree(unittest.TestCase):
@with_temp_dir
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
"flash_attention": False,
"sequence_len": 512,
"bf16": False,
"fp16": False,
"val_set_size": 0.02,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 0,
"eval_steps": 0,
"save_first_step": False,
# Aux-free plugin and toggles
"plugins": [
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
],
"moe_balance_type": "noaux_tc",
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
prepare_plugins(cfg)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
# check that at least one sparse MoE block has been patched
found = False
for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
assert m._afb_patched is True
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
found = True
break
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
check_model_output_exists(temp_dir, cfg)