diff --git a/src/axolotl/integrations/aux_free_router/README.md b/src/axolotl/integrations/aux_free_router/README.md new file mode 100644 index 000000000..7776ff94e --- /dev/null +++ b/src/axolotl/integrations/aux_free_router/README.md @@ -0,0 +1,38 @@ +# Aux-Loss-Free MoE Router Plugin + +This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code. + +Summary +- Bias only affects expert selection (top-k); mixture weights come from unbiased logits. +- Per-expert token loads are accumulated on device and reduced across DP or EP groups. +- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads. +- Existing aux loss is disabled when aux-free is enabled to avoid double signals. + +Enable +- Add the plugin to your YAML, then set the aux-free toggle: + + plugins: + - axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin + + moe_balance_type: noaux_tc + moe_update_rate: 0.01 # default if unset + moe_update_momentum: 0.9 # default if unset + moe_bias_cap: 2.0 # default if unset + moe_afb_warmup_steps: 100 # optional + moe_bias_sync_group: world # or 'ep' if expert-parallel is configured + +Config keys +- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native. +- moe_update_rate: bias update rate (gamma). Default: 0.01. +- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9. +- moe_bias_cap: absolute clamp for bias. Default: 2.0. +- moe_afb_warmup_steps: delay before applying updates. Default: 0. +- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world. + +Compatibility +- Targeted families: Mixtral, Qwen3-MoE. Jamba optional. +- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future. + +Notes +- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on. +- Telemetry: future updates will log per-expert loads and bias magnitudes. diff --git a/src/axolotl/integrations/aux_free_router/__init__.py b/src/axolotl/integrations/aux_free_router/__init__.py new file mode 100644 index 000000000..b3f78049b --- /dev/null +++ b/src/axolotl/integrations/aux_free_router/__init__.py @@ -0,0 +1,2 @@ +"""Aux-loss-free (AFB) MoE router integration package.""" + diff --git a/src/axolotl/integrations/aux_free_router/adapters.py b/src/axolotl/integrations/aux_free_router/adapters.py new file mode 100644 index 000000000..014c3b80d --- /dev/null +++ b/src/axolotl/integrations/aux_free_router/adapters.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Optional + +import torch +from torch import nn +import torch.nn.functional as F + +from axolotl.utils.logging import get_logger + +from .core import AuxFreeShim + +LOG = get_logger(__name__) + + +@dataclass +class LayerHandle: + layer: nn.Module + layer_idx: int + num_experts: int + top_k: int + + +class BaseMoEAdapter: + """Base adapter that discovers MoE layers and wraps their forward. + + Concrete adapters should implement discovery and per-layer attribute extraction. + """ + + family: str = "generic" + + def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim + return False + + def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover + return [] + + def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover + return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2))) + + def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover + return int(getattr(moe_layer, "num_experts")) + + def disable_aux_loss(self, model_or_layer: nn.Module) -> None: + # Best-effort: zero router aux loss coef if present + if hasattr(model_or_layer, "router_aux_loss_coef"): + try: + setattr(model_or_layer, "router_aux_loss_coef", 0.0) + except Exception: # pragma: no cover - non-critical + pass + + def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + """Attach per-layer buffers and mark as aux-free enabled. + + Note: Forward rebind happens in concrete adapters once we implement full routing. + For now, we only attach buffers as placeholders to minimize disturbance. + """ + device = next(moe_layer.parameters(), torch.tensor(0)).device + if not hasattr(moe_layer, "_afb_bias"): + moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device)) + if not hasattr(moe_layer, "_afb_counts"): + moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device)) + if not hasattr(moe_layer, "_afb_ema"): + moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device)) + moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined] + moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined] + self._patch_forward_with_aux_free(moe_layer) + + def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None: + """Replace the layer's forward with an aux-free gating version. + + Assumes the layer exposes attributes: + - gate: linear router projecting hidden to num_experts + - num_experts: int + - experts: iterable of expert modules taking (tokens, H) -> (tokens, H) + """ + if getattr(moe_layer, "_afb_patched", False): + return + + if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"): + LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch") + return + + def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef] + # hidden_states: (B, T, H) + bsz, seqlen, hdim = hidden_states.shape + hs = hidden_states.view(-1, hdim) + logits = self.gate(hs) + # selection uses biased logits; weights from unbiased logits + bias = getattr(self, "_afb_bias") + top_k = int(getattr(self, "_afb_top_k", 2)) + biased = logits + bias # broadcast over tokens + topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False) + chosen_logits = torch.gather(logits, -1, topk_idx) + weights = torch.softmax(chosen_logits.float(), dim=-1) + weights = weights.to(hs.dtype) + + # accumulate counts for bias update callback + flat_idx = topk_idx.reshape(-1) + counts = torch.bincount(flat_idx, minlength=int(self.num_experts)) + getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype)) + + # dispatch tokens to experts + hs_rep = hs.repeat_interleave(top_k, dim=0) + y = torch.empty_like(hs_rep) + for eid in range(int(self.num_experts)): + mask = flat_idx == eid + if mask.any(): + y[mask] = self.experts[eid](hs_rep[mask]) + + y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1) + out = y.view(bsz, seqlen, hdim) + return (out, logits) + + moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined] + setattr(moe_layer, "_afb_patched", True) + + +class MixtralAdapter(BaseMoEAdapter): + family = "mixtral" + + def matches(self, model: nn.Module) -> bool: + return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral" + + def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: + for m in model.modules(): + if m.__class__.__name__.endswith("SparseMoeBlock"): + yield m + + +class Qwen3Adapter(MixtralAdapter): + family = "qwen3_moe" + + def matches(self, model: nn.Module) -> bool: + return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe") + + +def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]: + """Discover MoE layers using the first matching adapter and attach per-layer buffers. + + Returns a list of layer handles for later routing patching and updates. + """ + handles: list[LayerHandle] = [] + adapter: Optional[BaseMoEAdapter] = None + for a in adapters: + if a.matches(model): + adapter = a + break + + if adapter is None: + LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing") + return handles + + # disable aux loss at model level if possible + adapter.disable_aux_loss(getattr(model, "config", model)) + + idx = 0 + for layer in adapter.find_moe_layers(model): + try: + top_k = adapter.get_top_k(layer) + nE = adapter.get_num_experts(layer) + except Exception: + continue + + handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k) + adapter.prepare(layer, handle, shim) + handles.append(handle) + idx += 1 + + LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing") + return handles diff --git a/src/axolotl/integrations/aux_free_router/core.py b/src/axolotl/integrations/aux_free_router/core.py new file mode 100644 index 000000000..a673a9856 --- /dev/null +++ b/src/axolotl/integrations/aux_free_router/core.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist + + +@dataclass +class AuxFreeConfig: + rate: float = 0.01 + momentum: float = 0.9 + bias_cap: float = 2.0 + warmup_steps: int = 0 + sync_group: str = "world" # or "ep" + + +class AuxFreeState: + """Holds per-layer bias and EMA load buffers.""" + + def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig): + self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)] + self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)] + self.cfg = cfg + self.steps = 0 + + +class AuxFreeShim: + """Model-agnostic shim for aux-loss-free expert selection and bias updates.""" + + def __init__(self, state: AuxFreeState, ep_group: Optional[dist.ProcessGroup] = None): + self.state = state + self.ep_group = ep_group + + @torch.no_grad() + def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + """Returns (topk_indices, weights) using biased selection and unbiased weights.""" + b = self.state.bias[layer_idx] + biased = logits + b # bias is a buffer + topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1) + chosen_logits = torch.gather(logits, -1, topk_idx) + weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype) + return topk_idx, weights + + @torch.no_grad() + def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor: + if not dist.is_available() or not dist.is_initialized(): + return counts + group = self.ep_group if self.ep_group is not None else dist.group.WORLD + dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group) + return counts + + @torch.no_grad() + def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int): + """Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering.""" + cfg = self.state.cfg + self.state.steps += 1 + if self.state.steps <= cfg.warmup_steps: + return + + nE = step_counts.numel() + if tokens_seen <= 0: + return + freq = step_counts.float() / float(tokens_seen) + ema = self.state.ema_load[layer_idx] + ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq) + target = 1.0 / float(nE) + delta = cfg.rate * (target - ema) + # optional mean-centering to keep sum(bias) ~ 0 + delta = delta - delta.mean() + bias = self.state.bias[layer_idx] + bias.add_(delta) + if cfg.bias_cap is not None and cfg.bias_cap > 0: + bias.clamp_(-cfg.bias_cap, cfg.bias_cap) + diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py new file mode 100644 index 000000000..3cd6b4f6c --- /dev/null +++ b/src/axolotl/integrations/aux_free_router/plugin.py @@ -0,0 +1,133 @@ +"""Aux-loss-free MoE Router Plugin for Axolotl. + +This plugin wires an aux-free gating option into compatible MoE models using +unbiased logits for mixture weights and per-expert biases for top-k selection. +""" + +from __future__ import annotations + +from typing import Optional + +import torch +from transformers.trainer_callback import TrainerCallback + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.distributed import is_distributed +from axolotl.utils.logging import get_logger + +from .adapters import BaseMoEAdapter, MixtralAdapter, Qwen3Adapter, discover_and_prepare_layers +from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState + +LOG = get_logger(__name__) + + +class MoeAuxFreeBiasUpdateCallback(TrainerCallback): + """Post-step callback to update aux-free biases from accumulated expert counts. + + Note: The current revision expects per-layer counts to be accumulated on each + MoE layer as a buffer named `_afb_counts` during forward (to be added with + routing patches in a follow-up). + """ + + def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]): + self.shim = shim + self.layer_modules = layer_modules + + def on_step_end(self, args, state, control, **kwargs): # noqa: D401 + # Iterate prepared MoE layers and apply the bias update rule. + cfg = self.shim.state.cfg + for layer in self.layer_modules: + if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"): + continue + counts = getattr(layer, "_afb_counts") + if counts is None: + continue + counts = counts.to(counts.device) + counts = self.shim.all_reduce_counts(counts) + tokens_seen = int(counts.sum().item()) + # local layer-state EMA and bias update + if tokens_seen > 0: + freq = counts.float() / float(tokens_seen) + ema = getattr(layer, "_afb_ema") + ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq) + nE = counts.numel() + target = 1.0 / float(nE) + delta = cfg.rate * (target - ema) + delta = delta - delta.mean() + bias = getattr(layer, "_afb_bias") + bias.add_(delta) + if cfg.bias_cap is not None and cfg.bias_cap > 0: + bias.clamp_(-cfg.bias_cap, cfg.bias_cap) + # reset step counts + counts.zero_() + return control + + +class AuxFreeMoEPlugin(BasePlugin): + """Plugin that enables aux-loss-free routing when configured.""" + + def __init__(self): + super().__init__() + self._handles: list = [] + self._shim: Optional[AuxFreeShim] = None + + def post_model_build(self, cfg, model): + # Enable only when explicitly requested + if getattr(cfg, "moe_balance_type", None) != "noaux_tc": + return + + # Be conservative — skip known native aux-free families + native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in ( + "deepseek_v3", + "glm4_moe", + ) + if native_auxfree: + LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching") + return + + # Build aux-free state and shim + rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01 + momentum = ( + cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9 + ) + bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0 + warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0 + sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world" + af_cfg = AuxFreeConfig( + rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group + ) + + # Discover layers to count the number and experts for state sizing + adapters: list[BaseMoEAdapter] = [MixtralAdapter(), Qwen3Adapter()] + + # For initial state sizing, we conservatively assume the first discovered layer defines nE + n_layers = 0 + n_experts = None + for m in model.modules(): + n_layers += 1 # upper bound — we will re-use bias slots sparsely + device = next(model.parameters(), torch.tensor(0)).device + if n_layers <= 0: + n_layers = 1 + if n_experts is None: + # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead + n_experts = 2 + state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg) + self._shim = AuxFreeShim(state=state, ep_group=None) + + # Discover and prepare layers (attach per-layer buffers) + self._handles = discover_and_prepare_layers(model, adapters, self._shim) + + LOG.info( + f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}" + ) + + def add_callbacks_post_trainer(self, cfg, trainer): + if getattr(cfg, "moe_balance_type", None) != "noaux_tc": + return [] + if self._shim is None: + return [] + # gather concrete layer modules from handles + layers = [h.layer for h in self._handles] + cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers) + LOG.info("AuxFreeMoE: registering post-step bias update callback") + return [cb] diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 67dea4958..bda30cf15 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -758,6 +758,44 @@ class AxolotlInputConfig( llama4_linearized_experts: bool | None = None + # MoE aux-loss-free (AFB) toggles + moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field( + default=None, + json_schema_extra={ + "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.", + }, + ) + moe_update_rate: float | None = Field( + default=None, + json_schema_extra={ + "description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. If unset, plugin default is 0.01.", + }, + ) + moe_update_momentum: float | None = Field( + default=None, + json_schema_extra={ + "description": "EMA momentum for expert load smoothing (0–1). If unset, plugin default is 0.9.", + }, + ) + moe_bias_cap: float | None = Field( + default=None, + json_schema_extra={ + "description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.", + }, + ) + moe_afb_warmup_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.", + }, + ) + moe_bias_sync_group: Literal["world", "ep"] | None = Field( + default=None, + json_schema_extra={ + "description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.", + }, + ) + deepspeed: str | dict[str, Any] | None = Field( default=None, json_schema_extra={ diff --git a/tests/e2e/test_moe_aux_free.py b/tests/e2e/test_moe_aux_free.py new file mode 100644 index 000000000..266fcc633 --- /dev/null +++ b/tests/e2e/test_moe_aux_free.py @@ -0,0 +1,79 @@ +""" +E2E smoke tests for Aux-Loss-Free MoE routing via plugin +""" + +import unittest + +import torch + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config, prepare_plugins +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestMoeAuxFree(unittest.TestCase): + """Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny.""" + + @with_temp_dir + def test_mixtral_aux_free_smoke(self, temp_dir): + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF", + "flash_attention": False, + "sequence_len": 512, + "bf16": False, + "fp16": False, + "val_set_size": 0.02, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_steps": 0, + "eval_steps": 0, + "save_first_step": False, + # Aux-free plugin and toggles + "plugins": [ + "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin", + ], + "moe_balance_type": "noaux_tc", + "moe_update_rate": 0.01, + "moe_update_momentum": 0.9, + "moe_bias_cap": 2.0, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + prepare_plugins(cfg) + dataset_meta = load_datasets(cfg=cfg) + + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) + + # Inspect model modules for a patched MoE layer + patched = None + for m in model.modules(): + if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True: + patched = m + break + assert patched is not None, "No MoE layer patched by aux-free plugin" + assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1 + assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1 + # ensure counts buffer got reset by callback (best effort) + assert torch.all(patched._afb_counts == 0) + + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_moe_aux_parity.py b/tests/e2e/test_moe_aux_parity.py new file mode 100644 index 000000000..11b5a920b --- /dev/null +++ b/tests/e2e/test_moe_aux_parity.py @@ -0,0 +1,83 @@ +""" +Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny. +Checks that aux-free training loss does not degrade beyond a small tolerance. +""" + +import unittest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config, prepare_plugins +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + + +def _last_logged_loss(trainer) -> float | None: + # Scan log_history for the most recent entry with a 'loss' key + for entry in reversed(trainer.state.log_history): + if isinstance(entry, dict) and "loss" in entry: + return float(entry["loss"]) + return None + + +class TestMoeAuxParity(unittest.TestCase): + @with_temp_dir + def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir): + base_cfg = { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF", + "flash_attention": False, + "sequence_len": 512, + "bf16": False, + "fp16": False, + "val_set_size": 0.02, + "special_tokens": {}, + "datasets": [ + {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-5, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 8, + "save_steps": 0, + "eval_steps": 0, + "save_first_step": False, + "seed": 42, + "logging_steps": 1, + } + + # Baseline: aux-loss (gshard) + cfg0 = DictDefault(dict(base_cfg)) + cfg0.output_dir = f"{temp_dir}/baseline" + cfg0 = validate_config(cfg0) + normalize_config(cfg0) + # baseline uses default aux-loss routing; no plugin registration + dataset_meta0 = load_datasets(cfg=cfg0) + model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0) + loss0 = _last_logged_loss(trainer0) + assert loss0 is not None + + # Aux-free: plugin + noaux_tc + cfg1 = DictDefault(dict(base_cfg)) + cfg1.output_dir = f"{temp_dir}/auxfree" + cfg1.plugins = [ + "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin", + ] + cfg1.moe_balance_type = "noaux_tc" + cfg1.moe_update_rate = 0.01 + cfg1.moe_update_momentum = 0.9 + cfg1.moe_bias_cap = 2.0 + cfg1 = validate_config(cfg1) + normalize_config(cfg1) + prepare_plugins(cfg1) + dataset_meta1 = load_datasets(cfg=cfg1) + model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1) + loss1 = _last_logged_loss(trainer1) + assert loss1 is not None + + # Assert aux-free loss is within 10% of aux-loss baseline + assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}" diff --git a/tests/e2e/test_qwen3_moe_aux_free.py b/tests/e2e/test_qwen3_moe_aux_free.py new file mode 100644 index 000000000..204ec4ad3 --- /dev/null +++ b/tests/e2e/test_qwen3_moe_aux_free.py @@ -0,0 +1,76 @@ +""" +E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny +""" + +import unittest + +import torch + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config, prepare_plugins +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestQwen3MoeAuxFree(unittest.TestCase): + @with_temp_dir + def test_qwen3_moe_aux_free_smoke(self, temp_dir): + cfg = DictDefault( + { + "base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM", + "tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM", + "flash_attention": False, + "sequence_len": 512, + "bf16": False, + "fp16": False, + "val_set_size": 0.02, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_steps": 0, + "eval_steps": 0, + "save_first_step": False, + # Aux-free plugin and toggles + "plugins": [ + "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin", + ], + "moe_balance_type": "noaux_tc", + "moe_update_rate": 0.01, + "moe_update_momentum": 0.9, + "moe_bias_cap": 2.0, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + prepare_plugins(cfg) + dataset_meta = load_datasets(cfg=cfg) + + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) + + # check that at least one sparse MoE block has been patched + found = False + for m in model.modules(): + if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"): + assert m._afb_patched is True + assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1 + assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1 + found = True + break + assert found, "No Qwen3-MoE sparse block patched by aux-free plugin" + + check_model_output_exists(temp_dir, cfg)