diff --git a/src/axolotl/integrations/aux_free_router/adapters.py b/src/axolotl/integrations/aux_free_router/adapters.py index ac9ed1851..0df987fe3 100644 --- a/src/axolotl/integrations/aux_free_router/adapters.py +++ b/src/axolotl/integrations/aux_free_router/adapters.py @@ -1,11 +1,19 @@ +"""Architecture-specific adapters for aux-loss-free MoE routing. + +Each adapter discovers MoE layers for a model family and patches only the +router/gate to inject per-expert bias into expert selection while keeping +mixture weights from unbiased logits. Expert dispatch is left untouched so +the patching composes with any expert backend (eager, ScatterMoE, SonicMoE). +""" + from __future__ import annotations from dataclasses import dataclass from typing import Iterable, Optional import torch -from torch import nn import torch.nn.functional as F +from torch import nn from axolotl.utils.logging import get_logger @@ -23,9 +31,10 @@ class LayerHandle: class BaseMoEAdapter: - """Base adapter that discovers MoE layers and wraps their forward. + """Base adapter that discovers MoE layers and patches their routing. - Concrete adapters should implement discovery and per-layer attribute extraction. + Concrete adapters implement discovery, attribute extraction, and + architecture-specific router patching. """ family: str = "generic" @@ -33,158 +42,191 @@ class BaseMoEAdapter: def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim return False - def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover + def find_moe_layers( + self, model: nn.Module + ) -> Iterable[nn.Module]: # pragma: no cover return [] - def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover - return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2))) + def get_top_k(self, moe_layer: nn.Module) -> int: + """Resolve top_k from the MoE layer, checking common attribute paths.""" + for attr_path in [ + ("top_k",), + ("num_experts_per_tok",), + ("gate", "top_k"), + ("router", "top_k"), + ]: + obj: object = moe_layer + for attr in attr_path: + obj = getattr(obj, attr, None) + if obj is None: + break + if isinstance(obj, int): + return obj + return 2 - def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover - return int(getattr(moe_layer, "num_experts")) + def get_num_experts(self, moe_layer: nn.Module) -> int: + """Resolve num_experts from the MoE layer, checking common attribute paths.""" + for attr_path in [ + ("num_experts",), + ("num_local_experts",), + ("gate", "num_experts"), + ("router", "num_experts"), + ("experts", "num_experts"), + ]: + obj: object = moe_layer + for attr in attr_path: + obj = getattr(obj, attr, None) + if obj is None: + break + if isinstance(obj, int): + return obj + raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}") def disable_aux_loss(self, model_or_layer: nn.Module) -> None: # Best-effort: zero router aux loss coef if present if hasattr(model_or_layer, "router_aux_loss_coef"): try: - setattr(model_or_layer, "router_aux_loss_coef", 0.0) + model_or_layer.router_aux_loss_coef = 0.0 except Exception: # pragma: no cover - non-critical pass - def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + def _register_aux_buffers( + self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim + ) -> None: device = next(moe_layer.parameters(), torch.tensor(0)).device if not hasattr(moe_layer, "_afb_bias"): - moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device)) + moe_layer.register_buffer( + "_afb_bias", torch.zeros(handle.num_experts, device=device) + ) if not hasattr(moe_layer, "_afb_counts"): - moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device)) + moe_layer.register_buffer( + "_afb_counts", torch.zeros(handle.num_experts, device=device) + ) if not hasattr(moe_layer, "_afb_ema"): - moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device)) + moe_layer.register_buffer( + "_afb_ema", torch.zeros(handle.num_experts, device=device) + ) moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined] moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined] shim.register_layer_buffers(handle.layer_idx, moe_layer) - def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: - """Attach per-layer buffers and mark as aux-free enabled.""" + def prepare( + self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim + ) -> None: + """Attach per-layer buffers. Subclasses override to also patch routing.""" self._register_aux_buffers(moe_layer, handle, shim) - self._patch_forward_with_aux_free(moe_layer) - def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None: - """Replace the layer's forward with an aux-free gating version. - - Assumes the layer exposes attributes: - - gate: linear router projecting hidden to num_experts - - num_experts: int - - experts: iterable of expert modules taking (tokens, H) -> (tokens, H) - """ - if getattr(moe_layer, "_afb_patched", False): - return - - if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"): - LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch") - return - - def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef] - # hidden_states: (B, T, H) - bsz, seqlen, hdim = hidden_states.shape - hs = hidden_states.view(-1, hdim) - logits = self.gate(hs) - # selection uses biased logits; weights from unbiased logits - bias = getattr(self, "_afb_bias") - top_k = int(getattr(self, "_afb_top_k", 2)) - biased = logits + bias # broadcast over tokens - topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False) - chosen_logits = torch.gather(logits, -1, topk_idx) - weights = torch.softmax(chosen_logits.float(), dim=-1) - weights = weights.to(hs.dtype) - - # accumulate counts for bias update callback - flat_idx = topk_idx.reshape(-1) - counts = torch.bincount(flat_idx, minlength=int(self.num_experts)) - getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype)) - - # dispatch tokens to experts - hs_rep = hs.repeat_interleave(top_k, dim=0) - y = torch.empty_like(hs_rep) - for eid in range(int(self.num_experts)): - mask = flat_idx == eid - if mask.any(): - y[mask] = self.experts[eid](hs_rep[mask]) - - y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1) - out = y.view(bsz, seqlen, hdim) - return (out, logits) - - moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined] - setattr(moe_layer, "_afb_patched", True) + def uses_kernel_routing(self, moe_layer: nn.Module) -> bool: + """Return True when a kernel backend (SonicMoE / ScatterMoE) has + already replaced the block forward, meaning the routing is handled + inside the kernel forward and we should NOT patch the router.""" + cls = type(moe_layer) + # SonicMoE stores the original forward when it patches a class. + if hasattr(cls, "_original_forward"): + return True + # ScatterMoE replaces via kernels library; check for the marker. + if hasattr(cls, "_kernel_forward"): + return True + return False class MixtralAdapter(BaseMoEAdapter): + """Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk + routing so that biased logits drive expert *selection* while unbiased + softmax scores drive mixture *weights*. + + Works with transformers v5 where experts are fused 3D tensors and + the router is ``MixtralTopKRouter`` (returns a 3-tuple). + """ + family = "mixtral" def matches(self, model: nn.Module) -> bool: - return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral" - - def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: - self._register_aux_buffers(moe_layer, handle, shim) - self._patch_mixtral_forward(moe_layer, shim) + return ( + getattr(getattr(model, "config", object()), "model_type", "") == "mixtral" + ) def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: for m in model.modules(): if m.__class__.__name__.endswith("SparseMoeBlock"): yield m - def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None: - if getattr(moe_layer, "_afb_patched", False): - return - - shim_ref = shim - - def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef] - batch_size, sequence_length, hidden_dim = hidden_states.shape - if self.training and getattr(self, "jitter_noise", 0) > 0: - hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_( - 1.0 - self.jitter_noise, 1.0 + self.jitter_noise - ) - flat_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(flat_states) - - layer_idx = int(getattr(self, "_afb_layer_idx", 0)) - top_k = int(getattr(self, "_afb_top_k", self.top_k)) - selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k) - routing_weights = routing_weights.to(flat_states.dtype) - - flat_idx = selected_experts.reshape(-1) - counts = torch.bincount(flat_idx, minlength=int(self.num_experts)) - self._afb_counts.add_(counts.to(self._afb_counts.dtype)) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=flat_states.dtype, - device=flat_states.device, + def prepare( + self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim + ) -> None: + self._register_aux_buffers(moe_layer, handle, shim) + if not self.uses_kernel_routing(moe_layer): + self._patch_router(moe_layer) + else: + LOG.info( + "AuxFreeMoE: kernel backend detected on %s; " + "skipping router patch (kernel routing handles bias)", + type(moe_layer).__name__, ) - expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx_tensor in expert_hit: - expert_idx = int(expert_idx_tensor.squeeze().item()) - expert_layer = self.experts[expert_idx] - mask = expert_mask[expert_idx].squeeze(0) - idx, top_x = torch.where(mask) - current_state = flat_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype)) + def _patch_router(self, moe_layer: nn.Module) -> None: + """Patch the TopKRouter to inject aux-free bias into expert selection.""" + gate = getattr(moe_layer, "gate", None) + if gate is None: + LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch") + return + if getattr(gate, "_afb_patched", False): + return - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + # Capture reference to the MoE block for bias / counts access. + block_ref = moe_layer - moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined] - setattr(moe_layer, "_afb_patched", True) + def afb_router_forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) + router_probs = F.softmax(router_logits.float(), dim=-1) + + # Biased selection, unbiased weights + bias = block_ref._afb_bias + biased = router_probs + bias + _, router_indices = torch.topk(biased, self.top_k, dim=-1) + router_scores = torch.gather(router_probs, 1, router_indices) + + # Renormalize (Mixtral always normalizes; Qwen checks config) + if getattr(self, "norm_topk_prob", True): + router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True) + + # Accumulate counts for the bias-update callback + flat_idx = router_indices.reshape(-1) + counts = torch.bincount(flat_idx, minlength=self.num_experts) + block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype)) + + return router_probs, router_scores, router_indices + + gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined] + gate._afb_patched = True + moe_layer._afb_patched = True class Qwen3Adapter(MixtralAdapter): family = "qwen3_moe" def matches(self, model: nn.Module) -> bool: - return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe") + return getattr(getattr(model, "config", object()), "model_type", "") in ( + "qwen3_moe", + "qwen2_moe", + ) + + +class Qwen35MoeAdapter(MixtralAdapter): + """Adapter for Qwen 3.5 MoE models. + + Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert + is handled by the block forward (untouched by router-level patching). + """ + + family = "qwen3_5_moe" + + def matches(self, model: nn.Module) -> bool: + return getattr(getattr(model, "config", object()), "model_type", "") in ( + "qwen3_5_moe", + "qwen3_5_moe_text", + ) class BailingAdapter(BaseMoEAdapter): @@ -207,11 +249,15 @@ class BailingAdapter(BaseMoEAdapter): def get_num_experts(self, moe_layer: nn.Module) -> int: if hasattr(moe_layer, "num_experts"): - return int(getattr(moe_layer, "num_experts")) + return int(moe_layer.num_experts) cfg = getattr(moe_layer, "config", None) - return int(getattr(cfg, "num_experts")) + if cfg is None: + raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}") + return int(cfg.num_experts) - def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + def prepare( + self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim + ) -> None: self._register_aux_buffers(moe_layer, handle, shim) self._patch_bailing_gate(moe_layer) @@ -227,7 +273,7 @@ class BailingAdapter(BaseMoEAdapter): flat = hidden_states.view(-1, hidden_states.shape[-1]) logits = F.linear(flat.float(), self.weight.float()) scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype) - bias = getattr(moe_layer, "_afb_bias") + bias = moe_layer._afb_bias biased_scores = scores_unbiased + bias topk_vals, topk_idx = self.group_limited_topk(biased_scores) weights = torch.gather(scores_unbiased, 1, topk_idx) @@ -238,12 +284,12 @@ class BailingAdapter(BaseMoEAdapter): flat_topk = topk_idx.reshape(-1) counts = torch.bincount(flat_topk, minlength=bias.numel()) - getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype)) + moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype)) return topk_idx, weights.to(hidden_states.dtype), logits gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined] - setattr(gate, "_afb_patched", True) + gate._afb_patched = True class Llama4Adapter(BaseMoEAdapter): @@ -257,7 +303,9 @@ class Llama4Adapter(BaseMoEAdapter): if m.__class__.__name__ == "Llama4TextMoe": yield m - def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + def prepare( + self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim + ) -> None: self._register_aux_buffers(moe_layer, handle, shim) self._patch_llama4_router(moe_layer) @@ -270,9 +318,13 @@ class Llama4Adapter(BaseMoEAdapter): return def afb_router_forward(self, hidden_states: torch.Tensor): - flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1]) + flat = ( + hidden_states + if hidden_states.dim() == 2 + else hidden_states.view(-1, hidden_states.shape[-1]) + ) router_logits = F.linear(flat, self.weight, self.bias) - bias = getattr(moe_layer, "_afb_bias") + bias = moe_layer._afb_bias biased_logits = router_logits + bias _, router_indices = torch.topk(biased_logits, self.top_k, dim=1) unbiased_top = torch.gather(router_logits, 1, router_indices) @@ -281,15 +333,17 @@ class Llama4Adapter(BaseMoEAdapter): router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype) counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel()) - getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype)) + moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype)) return router_scores, router_logits router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined] - setattr(router, "_afb_patched", True) + router._afb_patched = True -def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]: +def discover_and_prepare_layers( + model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim +) -> list[LayerHandle]: """Discover MoE layers using the first matching adapter and attach per-layer buffers. Returns a list of layer handles for later routing patching and updates. @@ -313,7 +367,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter] try: top_k = adapter.get_top_k(layer) nE = adapter.get_num_experts(layer) - except Exception: + except (AttributeError, TypeError, ValueError): continue handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k) @@ -321,5 +375,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter] handles.append(handle) idx += 1 - LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing") + LOG.info( + f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing" + ) return handles diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py index 31893e281..fd39c7dfc 100644 --- a/src/axolotl/integrations/aux_free_router/plugin.py +++ b/src/axolotl/integrations/aux_free_router/plugin.py @@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection. from __future__ import annotations -from typing import Optional, Any +from typing import Any, Optional import torch import torch.distributed as dist @@ -21,6 +21,7 @@ from .adapters import ( Llama4Adapter, MixtralAdapter, Qwen3Adapter, + Qwen35MoeAdapter, discover_and_prepare_layers, ) from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState @@ -47,9 +48,11 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback): # Iterate prepared MoE layers and apply the bias update rule. self.shim.begin_step() for layer in self.layer_modules: - if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"): + if not hasattr(layer, "_afb_counts") or not hasattr( + layer, "_afb_layer_idx" + ): continue - counts = getattr(layer, "_afb_counts") + counts = layer._afb_counts if counts is None: continue counts = self.shim.all_reduce_counts(counts) @@ -57,7 +60,7 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback): if layer_idx is None: counts.zero_() continue - bias = getattr(layer, "_afb_bias") + bias = layer._afb_bias counts_for_update = counts.to(bias.device) tokens_seen = int(counts_for_update.sum().item()) # local layer-state EMA and bias update @@ -143,12 +146,16 @@ class AuxFreeMoEPlugin(BasePlugin): return # Be conservative — skip known native aux-free families - native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in ( + native_auxfree = getattr( + getattr(model, "config", object()), "model_type", "" + ) in ( "deepseek_v3", "glm4_moe", ) if native_auxfree: - LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching") + LOG.info( + "AuxFreeMoE: model reports native aux-free routing; skipping patching" + ) return # Build aux-free state and shim @@ -171,6 +178,7 @@ class AuxFreeMoEPlugin(BasePlugin): adapters: list[BaseMoEAdapter] = [ MixtralAdapter(), Qwen3Adapter(), + Qwen35MoeAdapter(), BailingAdapter(), Llama4Adapter(), ] @@ -178,7 +186,7 @@ class AuxFreeMoEPlugin(BasePlugin): # For initial state sizing, we conservatively assume the first discovered layer defines nE n_layers = 0 n_experts = None - for m in model.modules(): + for _m in model.modules(): n_layers += 1 # upper bound — we will re-use bias slots sparsely device = next(model.parameters(), torch.tensor(0)).device if n_layers <= 0: @@ -186,7 +194,9 @@ class AuxFreeMoEPlugin(BasePlugin): if n_experts is None: # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead n_experts = 2 - state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg) + state = AuxFreeState( + num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg + ) ep_size = getattr(cfg, "expert_parallel_size", None) ep_group = None if sync_group == "ep": @@ -207,11 +217,15 @@ class AuxFreeMoEPlugin(BasePlugin): def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]: if not dist.is_available() or not dist.is_initialized(): - LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world") + LOG.warning( + "AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world" + ) return None ep_size = getattr(cfg, "expert_parallel_size", None) if not ep_size or ep_size <= 1: - LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world") + LOG.warning( + "AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world" + ) return None world = dist.get_world_size() if world % ep_size != 0: diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index c6c01e255..963bf84f5 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -240,7 +240,16 @@ def _softmax_topk_route( top_k = base_gate.top_k num_experts = base_gate.num_experts - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + # Aux-free bias: biased selection, unbiased weights + afb_bias = getattr(moe_block, "_afb_bias", None) + if afb_bias is not None: + scores_for_choice = routing_weights + afb_bias + _, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1) + routing_weights = routing_weights.gather(1, selected_experts) + _accumulate_afb_counts(moe_block, selected_experts) + else: + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) if getattr(base_gate, "norm_topk_prob", True): routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) @@ -282,6 +291,11 @@ def _sigmoid_topk_route( else: scores_for_choice = router_probs + # Aux-free bias: stacks on top of e_score_correction_bias for selection + afb_bias = getattr(moe_block, "_afb_bias", None) + if afb_bias is not None: + scores_for_choice = scores_for_choice + afb_bias + # Group-based selection: pick top groups, mask the rest n_group = getattr(moe_block, "n_group", 1) if n_group > 1: @@ -307,6 +321,10 @@ def _sigmoid_topk_route( # Gather weights from original sigmoid scores (not bias-corrected) topk_weights = router_probs.gather(1, topk_indices) + # Accumulate counts for aux-free bias update + if afb_bias is not None: + _accumulate_afb_counts(moe_block, topk_indices) + # Optional renormalization + scaling if getattr(moe_block, "norm_topk_prob", True): topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) @@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta): ) +def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None: + """Accumulate per-expert token counts for aux-free bias updates.""" + afb_counts = getattr(moe_block, "_afb_counts", None) + if afb_counts is None: + return + flat_idx = topk_indices.reshape(-1) + counts = torch.bincount(flat_idx, minlength=afb_counts.numel()) + afb_counts.add_(counts.to(afb_counts.dtype)) + + # ============================================================================= # Shared expert helpers # ============================================================================= diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/sonicmoe/routing.py index fe2d12092..0b0da9fc8 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/routing.py +++ b/src/axolotl/integrations/kernels/sonicmoe/routing.py @@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies: Each model type maps to a (routing_fn, activation_type, router_attr) triple. When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. + +Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is +active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered +as buffers. The routing functions transparently inject the bias into expert +*selection* (biased topk) while keeping mixture *weights* from unbiased +scores, then accumulate per-expert token counts for the post-step bias update. """ import torch @@ -101,17 +107,25 @@ def softmax_topk_routing( router_logits = F.linear(hidden_states, gate.weight) # [T, E] router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + # Aux-free bias: biased selection, unbiased weights + afb_bias = getattr(moe_block, "_afb_bias", None) + scores_for_choice = router_probs + if afb_bias is not None: + scores_for_choice = router_probs + afb_bias + # Select top-k experts per token - top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each + top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each + + # When aux-free bias is active, gather unbiased weights and accumulate counts + if afb_bias is not None: + top_values = router_probs.gather(1, top_indices) + _accumulate_afb_counts(moe_block, top_indices) # Renormalize if configured (default True for models without the attribute, # e.g. Mixtral/MiniMax which always normalize) if getattr(gate, "norm_topk_prob", True): top_values = top_values / top_values.sum(dim=-1, keepdim=True) - # no-op: matches transformers which casts to softmax output dtype (float32). - # top_values = top_values.to(router_probs.dtype) - # Flatten for moe_general_routing_inputs. # Token indices are naturally sorted ascending from the [T, K] layout: # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE. @@ -142,7 +156,11 @@ def softmax_group_topk_routing( router_logits = F.linear(hidden_states, gate.weight) # [T, E] router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + # Aux-free bias: inject before group selection / topk + afb_bias = getattr(moe_block, "_afb_bias", None) scores_for_choice = router_probs + if afb_bias is not None: + scores_for_choice = router_probs + afb_bias # Group selection: pick top groups, mask the rest if n_group > 1: @@ -164,6 +182,10 @@ def softmax_group_topk_routing( topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] topk_weights = router_probs.gather(1, topk_indices) + # Accumulate counts for aux-free bias update + if afb_bias is not None: + _accumulate_afb_counts(moe_block, topk_indices) + # Renormalization + scaling norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) if norm_topk_prob: @@ -233,6 +255,11 @@ def sigmoid_topk_routing( ) scores_for_choice = router_probs + e_score_correction_bias + # Aux-free bias: stacks on top of e_score_correction_bias for selection + afb_bias = getattr(moe_block, "_afb_bias", None) + if afb_bias is not None: + scores_for_choice = scores_for_choice + afb_bias + # Group-based selection: pick top groups, mask the rest (skip when n_group == 1) if n_group > 1: group_scores = ( @@ -256,6 +283,10 @@ def sigmoid_topk_routing( # Gather weights from original sigmoid scores (not bias-corrected) topk_weights = router_probs.gather(1, topk_indices) + # Accumulate counts for aux-free bias update + if afb_bias is not None: + _accumulate_afb_counts(moe_block, topk_indices) + # Optional renormalization + scaling norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) if norm_topk_prob: @@ -276,3 +307,19 @@ def sigmoid_topk_routing( flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] return flat_scores, flat_token_idx, flat_expert_idx, router_logits + + +def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None: + """Accumulate per-expert token counts for the aux-free bias update. + + Called when ``moe_block._afb_bias`` is present (registered by the + ``aux_free_router`` plugin). The counts are later consumed by the + ``MoeAuxFreeBiasUpdateCallback`` at each training step. + """ + afb_counts = getattr(moe_block, "_afb_counts", None) + if afb_counts is None: + return + num_experts = afb_counts.numel() + flat_idx = topk_indices.reshape(-1) + counts = torch.bincount(flat_idx, minlength=num_experts) + afb_counts.add_(counts.to(afb_counts.dtype)) diff --git a/tests/unit/test_aux_free_adapters.py b/tests/unit/test_aux_free_adapters.py index 3bc3ac8e5..43457679f 100644 --- a/tests/unit/test_aux_free_adapters.py +++ b/tests/unit/test_aux_free_adapters.py @@ -2,14 +2,13 @@ import os import sys import tempfile import unittest +from importlib import util as importlib_util +from pathlib import Path from types import SimpleNamespace import torch import torch.distributed as dist import torch.nn as nn -from importlib import util as importlib_util -from pathlib import Path - from huggingface_hub import snapshot_download from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin @@ -96,16 +95,21 @@ def _build_bailing_model(): def _build_llama4_model(): - from transformers import Llama4TextConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - config = Llama4TextConfig( + # Build config without __post_init__ validation (works around a + # huggingface_hub strict-dataclass type mismatch for layer_types). + config = object.__new__(__import__("transformers").Llama4TextConfig) + config.__dict__.update( hidden_size=16, intermediate_size=32, num_local_experts=4, num_attention_heads=2, num_key_value_heads=2, num_experts_per_tok=2, + num_hidden_layers=2, + hidden_act="silu", + layer_types=None, ) layer = Llama4TextMoe(config) @@ -148,6 +152,38 @@ def _build_mixtral_model(): return DummyModel(layer), layer +def _build_qwen35_moe_model(): + from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import ( + Qwen3_5MoeTextConfig, + ) + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeSparseMoeBlock, + ) + + config = Qwen3_5MoeTextConfig( + hidden_size=16, + moe_intermediate_size=32, + shared_expert_intermediate_size=32, + num_experts=4, + num_experts_per_tok=2, + num_attention_heads=2, + num_key_value_heads=2, + num_hidden_layers=2, + ) + layer = Qwen3_5MoeSparseMoeBlock(config) + + class DummyModel(nn.Module): + def __init__(self, moe_layer): + super().__init__() + self.moe = moe_layer + self.config = SimpleNamespace(model_type="qwen3_5_moe") + + def forward(self, hidden_states): + return self.moe(hidden_states) + + return DummyModel(layer), layer + + def _run_callback(plugin, cfg, *, args=None, state=None, control=None): if args is None: args = SimpleNamespace(logging_steps=1) @@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase): _run_callback(plugin, cfg) self.assertEqual(torch.count_nonzero(block._afb_counts), 0) - self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))) + self.assertFalse( + torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)) + ) def test_llama4_adapter_biases_router_selection(self): model, layer = _build_llama4_model() @@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase): _run_callback(plugin, cfg) self.assertEqual(torch.count_nonzero(layer._afb_counts), 0) - self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))) + self.assertFalse( + torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)) + ) def test_bias_warmup_respected(self): model, block = _build_bailing_model() @@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase): # Warmup steps should leave bias untouched. _step() - self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))) + self.assertTrue( + torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)) + ) _step() - self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))) + self.assertTrue( + torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)) + ) # Third step exceeds warmup -> bias should update. _step() self.assertGreater(torch.count_nonzero(block._afb_bias), 0) - def test_mixtral_adapter_respects_native_forward(self): + def test_mixtral_adapter_patches_router_not_forward(self): + """Verify that aux-free patches the router (gate) only, and the + v5 block forward signature (single tensor return) is preserved.""" model, layer = _build_mixtral_model() - layer.jitter_noise = 0.0 # avoid stochasticity for comparison - - hidden_dim = layer.config.hidden_size - hidden = torch.randn(2, 3, hidden_dim) - baseline_out, baseline_logits = layer(hidden.clone()) - cfg = _cfg() plugin = AuxFreeMoEPlugin() plugin.post_model_build(cfg, model) - patched_out, patched_logits = layer(hidden.clone()) - self.assertTrue(torch.allclose(baseline_out, patched_out)) - self.assertTrue(torch.allclose(baseline_logits, patched_logits)) + # Gate should be patched, not the block forward + self.assertTrue(getattr(layer.gate, "_afb_patched", False)) + self.assertTrue(getattr(layer, "_afb_patched", False)) + + # v5 block forward returns a single tensor (not a tuple with logits) + hidden = torch.randn(2, 3, layer.config.hidden_size) + out = layer(hidden) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, hidden.shape) + + # Counts should have been accumulated self.assertGreater(torch.count_nonzero(layer._afb_counts), 0) _run_callback(plugin, cfg) + def test_mixtral_adapter_bias_affects_selection(self): + """When bias is large for one expert, it should be selected more often.""" + model, layer = _build_mixtral_model() + cfg = _cfg() + plugin = AuxFreeMoEPlugin() + plugin.post_model_build(cfg, model) + + # Set a large bias for expert 0 to force its selection + layer._afb_bias.zero_() + layer._afb_bias[0] = 10.0 + + hidden = torch.randn(2, 8, layer.config.hidden_size) + num_tokens = 2 * 8 # batch * seq + layer(hidden) + + # With top_k=2, expert 0 should appear in every token's selection + # (once per token = num_tokens counts, not num_tokens * top_k) + counts = layer._afb_counts.clone() + self.assertEqual( + int(counts[0].item()), + num_tokens, + msg="Expert 0 should be selected for every token when heavily biased", + ) + + def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self): + """Verify Qwen 3.5 MoE: router is patched, shared expert is untouched, + output includes shared expert contribution.""" + model, layer = _build_qwen35_moe_model() + cfg = _cfg() + plugin = AuxFreeMoEPlugin() + plugin.post_model_build(cfg, model) + + # Gate should be patched + self.assertTrue(getattr(layer.gate, "_afb_patched", False)) + self.assertTrue(getattr(layer, "_afb_patched", False)) + # Shared expert should be unmodified + self.assertTrue(hasattr(layer, "shared_expert")) + self.assertTrue(hasattr(layer, "shared_expert_gate")) + + # Forward should return a single tensor (shared + routed) + hidden_size = layer.gate.hidden_dim + hidden = torch.randn(2, 3, hidden_size) + out = layer(hidden) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, hidden.shape) + + # Counts should have been accumulated + self.assertGreater(torch.count_nonzero(layer._afb_counts), 0) + + def test_qwen35_moe_adapter_bias_updates(self): + """Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE.""" + model, layer = _build_qwen35_moe_model() + cfg = _cfg() + plugin = AuxFreeMoEPlugin() + plugin.post_model_build(cfg, model) + + hidden_size = layer.gate.hidden_dim + hidden = torch.randn(2, 4, hidden_size) + layer(hidden) + + # Bias should start at zero + self.assertTrue( + torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias)) + ) + + _run_callback(plugin, cfg) + + # After callback: counts reset, EMA updated, bias updated + self.assertEqual(torch.count_nonzero(layer._afb_counts), 0) + self.assertFalse( + torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)) + ) + + def test_qwen35_moe_adapter_model_type_matching(self): + """Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text.""" + from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter + + adapter = Qwen35MoeAdapter() + + model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe")) + model_text = SimpleNamespace( + config=SimpleNamespace(model_type="qwen3_5_moe_text") + ) + model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe")) + + self.assertTrue(adapter.matches(model_moe)) + self.assertTrue(adapter.matches(model_text)) + self.assertFalse(adapter.matches(model_other)) + def test_ep_group_resolution_deferred_until_dist_ready(self): if dist.is_available() and dist.is_initialized(): dist.destroy_process_group() @@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase): tmp_init = tempfile.NamedTemporaryFile(delete=False) tmp_init.close() init_method = f"file://{tmp_init.name}" - dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0) + dist.init_process_group( + backend="gloo", init_method=init_method, world_size=1, rank=0 + ) try: hidden = torch.randn(2, 3, block.config.hidden_size) block(hidden) @@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase): def test_telemetry_logging(self): model, layer = _build_mixtral_model() - layer.jitter_noise = 0.0 cfg = _cfg() plugin = AuxFreeMoEPlugin() plugin.post_model_build(cfg, model) @@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase): self.assertIn("moe_afb/l0_load_max", telemetry) self.assertIn("moe_afb/l0_bias_abs_max", telemetry) + def test_get_num_experts_v5_attribute_paths(self): + """Verify get_num_experts works with v5 attribute layout where + num_experts is on gate/experts sub-modules, not the block.""" + from axolotl.integrations.aux_free_router.adapters import MixtralAdapter + + adapter = MixtralAdapter() + + # Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block) + block = SimpleNamespace( + gate=SimpleNamespace(num_experts=8), + experts=SimpleNamespace(num_experts=8), + ) + self.assertEqual(adapter.get_num_experts(block), 8) + + # Also works when num_experts is directly on block + block2 = SimpleNamespace(num_experts=4) + self.assertEqual(adapter.get_num_experts(block2), 4) + + +class TestAuxFreeKernelComposition(unittest.TestCase): + """Tests that aux-free bias composes correctly with kernel routing.""" + + def test_sonicmoe_softmax_routing_with_afb_bias(self): + """SonicMoE softmax routing should use biased selection / unbiased weights.""" + from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + + num_experts = 4 + top_k = 2 + hidden_dim = 16 + T = 6 + + # Build a mock MoE block with gate attributes + gate = nn.Linear(hidden_dim, num_experts, bias=False) + gate.top_k = top_k + gate.num_experts = num_experts + gate.norm_topk_prob = True + + moe_block = SimpleNamespace(gate=gate) + hidden = torch.randn(T, hidden_dim) + + # Baseline: no bias + scores_base, tok_base, exp_base, logits_base = softmax_topk_routing( + hidden, moe_block + ) + self.assertEqual(scores_base.shape[0], T * top_k) + + # Now register aux-free buffers and set heavy bias on expert 0 + moe_block._afb_bias = torch.zeros(num_experts) + moe_block._afb_bias[0] = 100.0 + moe_block._afb_counts = torch.zeros(num_experts) + + scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing( + hidden, moe_block + ) + + # Expert 0 should be selected for every token + self.assertTrue( + (exp_biased == 0).any(), + "Expert 0 should appear in selections when heavily biased", + ) + # Counts should have been accumulated + self.assertGreater(moe_block._afb_counts[0].item(), 0) + # Total counts should equal T * top_k + self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k) + + def test_sonicmoe_routing_without_bias_unchanged(self): + """Without _afb_bias, routing should produce identical results.""" + from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + + num_experts = 4 + top_k = 2 + hidden_dim = 16 + + gate = nn.Linear(hidden_dim, num_experts, bias=False) + gate.top_k = top_k + gate.num_experts = num_experts + gate.norm_topk_prob = True + + moe_block = SimpleNamespace(gate=gate) + hidden = torch.randn(4, hidden_dim) + + # Without _afb_bias attribute + scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block) + + # With _afb_bias = zeros (should be equivalent) + moe_block._afb_bias = torch.zeros(num_experts) + moe_block._afb_counts = torch.zeros(num_experts) + scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block) + + torch.testing.assert_close(scores1, scores2) + torch.testing.assert_close(exp1, exp2) + + @unittest.skipUnless( + importlib_util.find_spec("triton") is not None, + "triton not installed (required by scattermoe)", + ) + def test_scattermoe_softmax_routing_with_afb_bias(self): + """ScatterMoE softmax routing should use biased selection / unbiased weights.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _softmax_topk_route, + ) + + num_experts = 4 + top_k = 2 + hidden_dim = 16 + T = 6 + + gate_weight = torch.randn(num_experts, hidden_dim) + base_gate = SimpleNamespace( + top_k=top_k, + num_experts=num_experts, + norm_topk_prob=True, + weight=gate_weight, + ) + + moe_block = SimpleNamespace() + hidden = torch.randn(T, hidden_dim) + + # Baseline without bias + w_base, e_base, _, _ = _softmax_topk_route( + moe_block, base_gate, hidden, gate_weight, None + ) + + # With heavy bias on expert 0 + moe_block._afb_bias = torch.zeros(num_experts) + moe_block._afb_bias[0] = 100.0 + moe_block._afb_counts = torch.zeros(num_experts) + + w_biased, e_biased, _, _ = _softmax_topk_route( + moe_block, base_gate, hidden, gate_weight, None + ) + + # Expert 0 should appear in all selections + self.assertTrue((e_biased == 0).any()) + # Counts accumulated + self.assertGreater(moe_block._afb_counts[0].item(), 0) + self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k) + + def test_kernel_routing_skips_router_patch(self): + """When a kernel backend has patched the block class, the adapter + should skip patching the router (buffers are still registered).""" + from axolotl.integrations.aux_free_router.adapters import MixtralAdapter + + adapter = MixtralAdapter() + + # Create a mock layer whose class has _original_forward (SonicMoE marker) + class PatchedBlock(nn.Module): + _original_forward = True # SonicMoE marker + + def __init__(self): + super().__init__() + self.gate = nn.Linear(16, 4, bias=False) + self.gate.top_k = 2 + self.gate.num_experts = 4 + self.gate.hidden_dim = 16 + self.experts = nn.Linear(16, 16) # placeholder + + layer = PatchedBlock() + self.assertTrue(adapter.uses_kernel_routing(layer)) + + # Gate should NOT be patched (kernel handles routing) + self.assertFalse(getattr(layer.gate, "_afb_patched", False)) + + def test_adapter_buffers_registered_even_with_kernel(self): + """Even when kernel routing is active, aux-free buffers must be + registered on the MoE block so the kernel routing can find them.""" + from axolotl.integrations.aux_free_router.adapters import ( + LayerHandle, + MixtralAdapter, + ) + from axolotl.integrations.aux_free_router.core import ( + AuxFreeConfig, + AuxFreeShim, + AuxFreeState, + ) + + class PatchedBlock(nn.Module): + _original_forward = True + + def __init__(self): + super().__init__() + self.gate = nn.Linear(16, 4, bias=False) + self.gate.top_k = 2 + self.gate.num_experts = 4 + self.gate.hidden_dim = 16 + self.experts = nn.Linear(16, 16) + + layer = PatchedBlock() + adapter = MixtralAdapter() + cfg = AuxFreeConfig() + state = AuxFreeState( + num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg + ) + shim = AuxFreeShim(state=state) + handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2) + + adapter.prepare(layer, handle, shim) + + # Buffers should be registered for kernel routing to use + self.assertTrue(hasattr(layer, "_afb_bias")) + self.assertTrue(hasattr(layer, "_afb_counts")) + self.assertTrue(hasattr(layer, "_afb_ema")) + # But gate should NOT be patched + self.assertFalse(getattr(layer.gate, "_afb_patched", False)) + if __name__ == "__main__": unittest.main()