Compare commits: activeblue...lhl-moe-au (14 commits)

| SHA1 |
|---|
| 6636e5de7e |
| 0a566d7a15 |
| 5acb1b0ade |
| 4009a2ba5f |
| 66b2ab8414 |
| 676d5e855d |
| 966a4555db |
| ad0c825bcb |
| 46d677876e |
| 6eac9ac372 |
| 949cdf01eb |
| a0019021dd |
| 2af7475fdf |
| 3e4688289c |
src/axolotl/integrations/aux_free_router/README.md (new file, +50 lines)
@@ -0,0 +1,50 @@
# Aux-Loss-Free MoE Router Plugin

This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.

Summary

- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
- Bias is updated post-optimizer step, outside autograd, using EMA-smoothed loads.
- Any existing aux loss is disabled when aux-free is enabled, so the router doesn't receive two competing balancing signals.

Enable

- Add the plugin to your YAML, then set the aux-free toggle:

```yaml
plugins:
  - axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin

moe_balance_type: noaux_tc
moe_update_rate: 0.01       # default if unset
moe_update_momentum: 0.9    # default if unset
moe_bias_cap: 2.0           # default if unset
moe_afb_warmup_steps: 100   # optional
moe_bias_sync_group: world  # or 'ep' if expert_parallel_size > 1
expert_parallel_size: 1     # set to your EP width when using moe_bias_sync_group: ep
```

Config keys

- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
- moe_update_rate: bias update rate (gamma). Default: 0.01.
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).

Compatibility

- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
- Pass-through: models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in the future.

Notes

- If you also enable Liger's aux-loss paths, the plugin neutralizes the aux loss when aux-free is on.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and the bias sign-flip fraction on the Trainer's `logging_steps` cadence.
- Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.

Telemetry metrics

- `moe_afb/l{idx}_load_min|mean|max`: token frequency per expert after reduction (0–1 range; sums to 1 across experts).
- `moe_afb/l{idx}_bias_abs_max`: absolute maximum of the learned bias for the layer.
- `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (a simple oscillation indicator).

Usage tips

- Increase `logging_steps` if router telemetry becomes noisy for large jobs; the plugin follows the Trainer's logging cadence.
- Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens the min/max spread without the auxiliary loss term.
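For orientation, the selection rule the Summary describes is small enough to sketch standalone. A minimal illustration with hypothetical sizes and bias values, using plain torch (not part of the diff):

```python
import torch

T, E, K = 4, 8, 2                      # tokens, experts, top-k (illustrative)
torch.manual_seed(0)
logits = torch.randn(T, E)             # router logits for a batch of tokens
bias = torch.zeros(E)                  # per-expert AFB bias (a buffer, not a parameter)
bias[3] = -0.5                         # pretend expert 3 has been overloaded

probs = torch.softmax(logits, dim=-1)           # unbiased scores
_, idx = torch.topk(probs + bias, K, dim=-1)    # bias affects *selection* only
weights = probs.gather(1, idx)                  # weights come from unbiased scores
weights = weights / weights.sum(-1, keepdim=True)  # renormalize over chosen experts

# per-expert token counts feed the post-step bias update
counts = torch.bincount(idx.reshape(-1), minlength=E)
```

Note that gradients still flow through `weights` exactly as in aux-loss routing; the bias only reshuffles which experts get picked.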
src/axolotl/integrations/aux_free_router/__init__.py (new file, +9 lines)
@@ -0,0 +1,9 @@
```python
"""Aux-loss-free (AFB) MoE router integration package."""

from .args import AuxFreeRouterArgs
from .plugin import AuxFreeMoEPlugin

__all__ = [
    "AuxFreeMoEPlugin",
    "AuxFreeRouterArgs",
]
```
src/axolotl/integrations/aux_free_router/adapters.py (new file, +393 lines)
@@ -0,0 +1,393 @@
```python
"""Architecture-specific adapters for aux-loss-free MoE routing.

Each adapter discovers MoE layers for a model family and patches only the
router/gate to inject per-expert bias into expert selection while keeping
mixture weights from unbiased logits. Expert dispatch is left untouched so
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional

import torch
import torch.nn.functional as F
from torch import nn

from axolotl.utils.logging import get_logger

from .core import AuxFreeShim

LOG = get_logger(__name__)


@dataclass
class LayerHandle:
    layer: nn.Module
    layer_idx: int
    num_experts: int
    top_k: int


class BaseMoEAdapter:
    """Base adapter that discovers MoE layers and patches their routing.

    Concrete adapters implement discovery, attribute extraction, and
    architecture-specific router patching.
    """

    family: str = "generic"

    def matches(self, model: nn.Module) -> bool:  # pragma: no cover - thin shim
        return False

    def find_moe_layers(
        self, model: nn.Module
    ) -> Iterable[nn.Module]:  # pragma: no cover
        return []

    def get_top_k(self, moe_layer: nn.Module) -> int:
        """Resolve top_k from the MoE layer, checking common attribute paths."""
        for attr_path in [
            ("top_k",),
            ("num_experts_per_tok",),
            ("gate", "top_k"),
            ("router", "top_k"),
        ]:
            obj: object = moe_layer
            for attr in attr_path:
                obj = getattr(obj, attr, None)
                if obj is None:
                    break
            if isinstance(obj, int):
                return obj
        return 2

    def get_num_experts(self, moe_layer: nn.Module) -> int:
        """Resolve num_experts from the MoE layer, checking common attribute paths."""
        for attr_path in [
            ("num_experts",),
            ("num_local_experts",),
            ("gate", "num_experts"),
            ("router", "num_experts"),
            ("experts", "num_experts"),
        ]:
            obj: object = moe_layer
            for attr in attr_path:
                obj = getattr(obj, attr, None)
                if obj is None:
                    break
            if isinstance(obj, int):
                return obj
        raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")

    def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
        # Best-effort: zero the router aux loss coef if present
        if hasattr(model_or_layer, "router_aux_loss_coef"):
            try:
                model_or_layer.router_aux_loss_coef = 0.0
            except Exception:  # pragma: no cover - non-critical
                LOG.debug(
                    "disable_aux_loss: failed to set router_aux_loss_coef on %s",
                    type(model_or_layer).__name__,
                    exc_info=True,
                )

    def _register_aux_buffers(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        p = next(moe_layer.parameters(), None)
        b = next(moe_layer.buffers(), None)
        device = (
            p.device
            if p is not None
            else (b.device if b is not None else torch.device("cpu"))
        )
        if not hasattr(moe_layer, "_afb_bias"):
            moe_layer.register_buffer(
                "_afb_bias", torch.zeros(handle.num_experts, device=device)
            )
        if not hasattr(moe_layer, "_afb_counts"):
            moe_layer.register_buffer(
                "_afb_counts", torch.zeros(handle.num_experts, device=device)
            )
        if not hasattr(moe_layer, "_afb_ema"):
            moe_layer.register_buffer(
                "_afb_ema", torch.zeros(handle.num_experts, device=device)
            )
        moe_layer._afb_layer_idx = handle.layer_idx  # type: ignore[attr-defined]
        moe_layer._afb_top_k = handle.top_k  # type: ignore[attr-defined]
        shim.register_layer_buffers(handle.layer_idx, moe_layer)

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        """Attach per-layer buffers. Subclasses override to also patch routing."""
        self._register_aux_buffers(moe_layer, handle, shim)

    def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
        """Return True when a kernel backend (SonicMoE / ScatterMoE) has
        already replaced the block forward, meaning the routing is handled
        inside the kernel forward and we should NOT patch the router."""
        cls = type(moe_layer)
        # SonicMoE stores the original forward when it patches a class.
        if hasattr(cls, "_original_forward"):
            return True
        # ScatterMoE replaces via kernels library; check for the marker.
        if hasattr(cls, "_kernel_forward"):
            return True
        return False


class MixtralAdapter(BaseMoEAdapter):
    """Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
    routing so that biased logits drive expert *selection* while unbiased
    softmax scores drive mixture *weights*.

    Works with transformers v5 where experts are fused 3D tensors and
    the router is ``MixtralTopKRouter`` (returns a 3-tuple).
    """

    family = "mixtral"

    def matches(self, model: nn.Module) -> bool:
        return (
            getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
        )

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        for m in model.modules():
            if m.__class__.__name__.endswith("SparseMoeBlock"):
                yield m

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        self._register_aux_buffers(moe_layer, handle, shim)
        if not self.uses_kernel_routing(moe_layer):
            self._patch_router(moe_layer)
        else:
            LOG.info(
                "AuxFreeMoE: kernel backend detected on %s; "
                "skipping router patch (kernel routing handles bias)",
                type(moe_layer).__name__,
            )

    def _patch_router(self, moe_layer: nn.Module) -> None:
        """Patch the TopKRouter to inject aux-free bias into expert selection."""
        gate = getattr(moe_layer, "gate", None)
        if gate is None:
            LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
            return
        if getattr(gate, "_afb_patched", False):
            return

        # Capture reference to the MoE block for bias / counts access.
        block_ref = moe_layer

        def afb_router_forward(self, hidden_states: torch.Tensor):
            hidden_states = hidden_states.reshape(-1, self.hidden_dim)
            router_logits = F.linear(hidden_states, self.weight)
            router_probs = F.softmax(router_logits.float(), dim=-1)

            # Biased selection, unbiased weights
            bias = block_ref._afb_bias
            biased = router_probs + bias
            _, router_indices = torch.topk(biased, self.top_k, dim=-1)
            router_scores = torch.gather(router_probs, 1, router_indices)

            # Renormalize (Mixtral always normalizes; Qwen checks config)
            if getattr(self, "norm_topk_prob", True):
                router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)

            # Accumulate counts for the bias-update callback
            flat_idx = router_indices.reshape(-1)
            counts = torch.bincount(flat_idx, minlength=self.num_experts)
            block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))

            return router_probs, router_scores, router_indices

        gate.forward = afb_router_forward.__get__(gate, gate.__class__)  # type: ignore[attr-defined]
        gate._afb_patched = True
        moe_layer._afb_patched = True


class Qwen3Adapter(MixtralAdapter):
    family = "qwen3_moe"

    def matches(self, model: nn.Module) -> bool:
        return getattr(getattr(model, "config", object()), "model_type", "") in (
            "qwen3_moe",
            "qwen2_moe",
        )


class Qwen35MoeAdapter(MixtralAdapter):
    """Adapter for Qwen 3.5 MoE models.

    Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
    is handled by the block forward (untouched by router-level patching).
    """

    family = "qwen3_5_moe"

    def matches(self, model: nn.Module) -> bool:
        return getattr(getattr(model, "config", object()), "model_type", "") in (
            "qwen3_5_moe",
            "qwen3_5_moe_text",
        )


class BailingAdapter(BaseMoEAdapter):
    family = "bailing_moe"

    def matches(self, model: nn.Module) -> bool:
        cfg = getattr(model, "config", None)
        if cfg is None:
            return False
        model_type = getattr(cfg, "model_type", "") or ""
        if model_type in ("bailing_moe", "bailing_moe_v2", "ring_moe", "ring"):
            return True
        cfg_name = cfg.__class__.__name__.lower()
        return "bailingmoev2" in cfg_name or "ring" in cfg_name

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        for m in model.modules():
            if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
                yield m

    def get_num_experts(self, moe_layer: nn.Module) -> int:
        if hasattr(moe_layer, "num_experts"):
            return int(moe_layer.num_experts)
        cfg = getattr(moe_layer, "config", None)
        if cfg is None:
            raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
        return int(cfg.num_experts)

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        self._register_aux_buffers(moe_layer, handle, shim)
        self._patch_bailing_gate(moe_layer)

    def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
        gate = getattr(moe_layer, "gate", None)
        if gate is None:
            LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
            return
        if getattr(gate, "_afb_patched", False):
            return

        def afb_gate_forward(self, hidden_states: torch.Tensor):
            flat = hidden_states.view(-1, hidden_states.shape[-1])
            logits = F.linear(flat.float(), self.weight.float())
            scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
            bias = moe_layer._afb_bias
            biased_scores = scores_unbiased + bias
            _, topk_idx = self.group_limited_topk(biased_scores)
            weights = torch.gather(scores_unbiased, 1, topk_idx)
            if self.top_k > 1:
                denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
                weights = weights / denom
            weights = weights * self.routed_scaling_factor

            flat_topk = topk_idx.reshape(-1)
            counts = torch.bincount(flat_topk, minlength=bias.numel())
            moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))

            return topk_idx, weights.to(hidden_states.dtype), logits

        gate.forward = afb_gate_forward.__get__(gate, gate.__class__)  # type: ignore[attr-defined]
        gate._afb_patched = True


class Llama4Adapter(BaseMoEAdapter):
    family = "llama4"

    def matches(self, model: nn.Module) -> bool:
        return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        for m in model.modules():
            if m.__class__.__name__ == "Llama4TextMoe":
                yield m

    def prepare(
        self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
    ) -> None:
        self._register_aux_buffers(moe_layer, handle, shim)
        self._patch_llama4_router(moe_layer)

    def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
        router = getattr(moe_layer, "router", None)
        if router is None:
            LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
            return
        if getattr(router, "_afb_patched", False):
            return

        def afb_router_forward(self, hidden_states: torch.Tensor):
            flat = (
                hidden_states
                if hidden_states.dim() == 2
                else hidden_states.view(-1, hidden_states.shape[-1])
            )
            router_logits = F.linear(flat, self.weight, self.bias)
            bias = moe_layer._afb_bias
            biased_logits = router_logits + bias
            _, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
            unbiased_top = torch.gather(router_logits, 1, router_indices)
            router_scores = torch.full_like(router_logits, float("-inf"))
            router_scores.scatter_(1, router_indices, unbiased_top)
            router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)

            counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
            moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))

            return router_scores, router_logits

        router.forward = afb_router_forward.__get__(router, router.__class__)  # type: ignore[attr-defined]
        router._afb_patched = True


def discover_and_prepare_layers(
    model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
) -> list[LayerHandle]:
    """Discover MoE layers using the first matching adapter and attach per-layer buffers.

    Returns a list of layer handles for later routing patching and updates.
    """
    handles: list[LayerHandle] = []
    adapter: Optional[BaseMoEAdapter] = None
    for a in adapters:
        if a.matches(model):
            adapter = a
            break

    if adapter is None:
        LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
        return handles

    # disable aux loss at model level if possible
    adapter.disable_aux_loss(getattr(model, "config", model))

    idx = 0
    for layer in adapter.find_moe_layers(model):
        try:
            top_k = adapter.get_top_k(layer)
            nE = adapter.get_num_experts(layer)
        except (AttributeError, TypeError, ValueError):
            continue

        handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
        adapter.prepare(layer, handle, shim)
        handles.append(handle)
        idx += 1

    LOG.info(
        "AuxFreeMoE: prepared %d %s layers for aux-free routing",
        len(handles),
        adapter.family,
    )
    return handles
```
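All three router patches above bind a plain function to a single instance with `func.__get__(instance, cls)`. A minimal sketch of that pattern (hypothetical `Gate` class) shows why only the patched instance changes while other instances keep the class-level forward:

```python
class Gate:
    def forward(self, x):
        return x

gate = Gate()

def patched_forward(self, x):
    # `self` is the gate instance, so its attributes stay accessible
    return x + 1

# __get__ turns the free function into a method bound to this one instance;
# assigning it as an instance attribute shadows the class-level forward.
gate.forward = patched_forward.__get__(gate, Gate)

assert gate.forward(1) == 2      # patched instance
assert Gate().forward(1) == 1    # fresh instances are untouched
```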
src/axolotl/integrations/aux_free_router/args.py (new file, +71 lines)
@@ -0,0 +1,71 @@
```python
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Plugin args for the Aux-Loss-Free MoE router integration.
"""

from typing import Literal

from pydantic import BaseModel, Field


class AuxFreeRouterArgs(BaseModel):
    """
    Input args for Aux-Loss-Free MoE routing.
    """

    moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
            "'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
            "Defaults to model's native behavior when unset."
        },
    )
    moe_update_rate: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
            "If unset, plugin default is 0.01."
        },
    )
    moe_update_momentum: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "EMA momentum for expert load smoothing (0-1). "
            "If unset, plugin default is 0.9."
        },
    )
    moe_bias_cap: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Absolute clamp for expert bias magnitude. "
            "If unset, plugin default is 2.0."
        },
    )
    moe_afb_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of initial steps to delay aux-free bias updates, "
            "allowing routing to stabilize. If unset, plugin default is 0."
        },
    )
    moe_bias_sync_group: Literal["world", "ep"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Reduction group for expert load counts: 'world' (DP) or "
            "'ep' (expert-parallel group if available). Defaults to 'world' when unset."
        },
    )
```
src/axolotl/integrations/aux_free_router/core.py (new file, +166 lines)
@@ -0,0 +1,166 @@
```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.distributed as dist

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


@dataclass
class AuxFreeConfig:
    rate: float = 0.01
    momentum: float = 0.9
    bias_cap: float = 2.0
    warmup_steps: int = 0
    sync_group: str = "world"  # or "ep"


class AuxFreeState:
    """Holds per-layer bias and EMA load buffers."""

    def __init__(
        self,
        num_layers: int,
        num_experts: int,
        device: torch.device,
        cfg: AuxFreeConfig,
    ):
        self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
        self.ema_load = [
            torch.zeros(num_experts, device=device) for _ in range(num_layers)
        ]
        self.cfg = cfg
        self.steps = 0


class AuxFreeShim:
    """Model-agnostic shim for aux-loss-free expert selection and bias updates."""

    def __init__(
        self,
        state: AuxFreeState,
        ep_group: Optional[dist.ProcessGroup] = None,
        ep_size: Optional[int] = None,
    ):
        self.state = state
        self.ep_group = ep_group
        self._ep_size = ep_size
        self._ep_group_pending = (
            self.state.cfg.sync_group == "ep" and self.ep_group is None
        )
        self._layer_modules: dict[int, torch.nn.Module] = {}
        self._prev_bias_sign: dict[int, torch.Tensor] = {}

    @torch.no_grad()
    def select_experts(
        self, layer_idx: int, logits: torch.Tensor, top_k: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (topk_indices, weights) using biased selection and unbiased weights."""
        module = self._layer_modules.get(layer_idx)
        if module is not None and hasattr(module, "_afb_bias"):
            b = module._afb_bias
        else:
            b = self.state.bias[layer_idx]
        biased = logits + b  # bias is a buffer
        _topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
        chosen_logits = torch.gather(logits, -1, topk_idx)
        weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
        return topk_idx, weights

    def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
        """Bind model buffers so shim updates stay in sync with patched layers."""
        self._layer_modules[layer_idx] = module
        bias = module._afb_bias
        ema = module._afb_ema
        # Keep state views pointing to the same tensors to avoid drift.
        if layer_idx < len(self.state.bias):
            self.state.bias[layer_idx] = bias
        if layer_idx < len(self.state.ema_load):
            self.state.ema_load[layer_idx] = ema

    def begin_step(self) -> None:
        """Call once per optimizer step before per-layer updates."""
        self.state.steps += 1

    def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
        return self._prev_bias_sign.get(layer_idx)

    @torch.no_grad()
    def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
        self._maybe_init_ep_group()
        if not dist.is_available() or not dist.is_initialized():
            return counts
        group = self.ep_group if self.ep_group is not None else dist.group.WORLD
        dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
        return counts

    @torch.no_grad()
    def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
        """Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
        cfg = self.state.cfg
        if self.state.steps <= cfg.warmup_steps:
            return

        nE = step_counts.numel()
        if tokens_seen <= 0:
            return
        module = self._layer_modules.get(layer_idx)
        if module is not None and hasattr(module, "_afb_ema"):
            ema = module._afb_ema
            bias = module._afb_bias
        else:
            ema = self.state.ema_load[layer_idx]
            bias = self.state.bias[layer_idx]
        counts = step_counts.to(ema.device)
        freq = counts.float() / float(tokens_seen)
        ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
        target = 1.0 / float(nE)
        delta = cfg.rate * (target - ema)
        # optional mean-centering to keep sum(bias) ~ 0
        delta = delta - delta.mean()
        bias.add_(delta)
        if cfg.bias_cap is not None and cfg.bias_cap > 0:
            bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
        self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())

    def _maybe_init_ep_group(self) -> None:
        if not self._ep_group_pending:
            return
        if not dist.is_available() or not dist.is_initialized():
            return
        ep_size = self._ep_size
        if not ep_size or ep_size <= 1:
            LOG.warning(
                "AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
            )
            self.ep_group = dist.group.WORLD
            self._ep_group_pending = False
            return
        world = dist.get_world_size()
        if world % ep_size != 0:
            LOG.warning(
                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
                ep_size,
                world,
            )
            self.ep_group = dist.group.WORLD
            self._ep_group_pending = False
            return
        if ep_size == world:
            self.ep_group = dist.group.WORLD
        else:
            rank = dist.get_rank()
            group_start = (rank // ep_size) * ep_size
            ranks = tuple(range(group_start, group_start + ep_size))
            self.ep_group = dist.new_group(ranks)
        LOG.info(
            "AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
            ep_size,
            dist.get_world_size(),
        )
        self._ep_group_pending = False
```
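Written out, the `update_bias` rule above is, per layer (restating the code, with $m$ the momentum, $\gamma$ the rate, $E$ the expert count, $f_e$ this step's token frequency for expert $e$, and $c$ the bias cap):

$$
\begin{aligned}
\bar f_e &\leftarrow m\,\bar f_e + (1-m)\,f_e \\
\delta_e &= \gamma\left(\tfrac{1}{E} - \bar f_e\right) - \frac{1}{E}\sum_{e'}\gamma\left(\tfrac{1}{E} - \bar f_{e'}\right) \\
b_e &\leftarrow \operatorname{clip}\left(b_e + \delta_e,\ -c,\ c\right)
\end{aligned}
$$

The mean-centering makes the update redistribute bias between experts rather than let all biases drift together: overloaded experts (EMA above $1/E$) are pushed down by exactly as much as underloaded ones are pulled up.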
src/axolotl/integrations/aux_free_router/plugin.py (new file, +267 lines)
@@ -0,0 +1,267 @@
```python
"""Aux-loss-free MoE Router Plugin for Axolotl.

This plugin wires an aux-free gating option into compatible MoE models using
unbiased logits for mixture weights and per-expert biases for top-k selection.
"""

from __future__ import annotations

from typing import Any, Optional

import torch
import torch.distributed as dist
from transformers.trainer_callback import TrainerCallback

from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger

from .adapters import (
    BailingAdapter,
    BaseMoEAdapter,
    Llama4Adapter,
    MixtralAdapter,
    Qwen3Adapter,
    Qwen35MoeAdapter,
    discover_and_prepare_layers,
)
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState

LOG = get_logger(__name__)


class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
    """Post-step callback to update aux-free biases from accumulated expert counts."""

    def __init__(
        self,
        shim: AuxFreeShim,
        layer_modules: list[torch.nn.Module],
        trainer: Any,
    ):
        self.shim = shim
        self.layer_modules = layer_modules
        self.trainer = trainer
        self._prev_bias_sign: dict[int, torch.Tensor] = {}
        self._telemetry_buffer: dict[int, dict[str, float]] = {}

    def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
        # Iterate prepared MoE layers and apply the bias update rule.
        self.shim.begin_step()
        for layer in self.layer_modules:
            if not hasattr(layer, "_afb_counts") or not hasattr(
                layer, "_afb_layer_idx"
            ):
                continue
            counts = layer._afb_counts
            if counts is None:
                continue
            counts = self.shim.all_reduce_counts(counts)
            layer_idx = getattr(layer, "_afb_layer_idx", None)
            if layer_idx is None:
                counts.zero_()
                continue
            bias = layer._afb_bias
            counts_for_update = counts.to(bias.device)
            tokens_seen = int(counts_for_update.sum().item())
            # local layer-state EMA and bias update
            self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
            self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
            # reset step counts
            counts.zero_()

        if self._should_log(args, state) and self._telemetry_buffer:
            logs: dict[str, float] = {}
            for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
                prefix = f"moe_afb/l{layer_idx}_"
                for key, value in metrics.items():
                    logs[f"{prefix}{key}"] = value
            if logs and hasattr(self.trainer, "log"):
                self.trainer.log(logs)
            self._telemetry_buffer.clear()
        return control

    def _collect_telemetry(
        self,
        layer_idx: int,
        counts: torch.Tensor,
        tokens_seen: int,
        bias: torch.Tensor,
    ) -> None:
        if tokens_seen <= 0:
            return
        freq = counts.float() / float(tokens_seen)
        load_min = freq.min().item()
        load_mean = freq.mean().item()
        load_max = freq.max().item()
        bias_abs_max = bias.abs().max().item()

        prev_sign = self._prev_bias_sign.get(layer_idx)
        current_sign = torch.sign(bias.detach())
        if prev_sign is None or prev_sign.shape != current_sign.shape:
            oscillation = 0.0
        else:
            changed = (current_sign != prev_sign) & (
                (current_sign != 0) | (prev_sign != 0)
            )
            if changed.numel() == 0:
                oscillation = 0.0
            else:
                oscillation = changed.float().mean().item()
        self._prev_bias_sign[layer_idx] = current_sign.clone()

        self._telemetry_buffer[layer_idx] = {
            "load_min": load_min,
            "load_mean": load_mean,
            "load_max": load_max,
            "bias_abs_max": bias_abs_max,
            "bias_sign_flip_frac": oscillation,
        }

    def _should_log(self, args, state) -> bool:
        interval = getattr(args, "logging_steps", 0)
        if not interval:
            return False
        try:
            interval = max(1, int(interval))
        except (TypeError, ValueError):
            return False
        return interval > 0 and state.global_step % interval == 0


class AuxFreeMoEPlugin(BasePlugin):
    """Plugin that enables aux-loss-free routing when configured."""

    def __init__(self):
        super().__init__()
        self._handles: list = []
        self._shim: Optional[AuxFreeShim] = None
        self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}

    def get_input_args(self):
        return "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"

    def post_model_build(self, cfg, model):
        # Enable only when explicitly requested
        if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
            return

        # Be conservative — skip known native aux-free families
        native_auxfree = getattr(
            getattr(model, "config", object()), "model_type", ""
        ) in (
            "deepseek_v3",
            "glm4_moe",
        )
        if native_auxfree:
            LOG.info(
                "AuxFreeMoE: model reports native aux-free routing; skipping patching"
            )
            return

        # Build aux-free state and shim
        rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
        momentum = (
            cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
        )
        bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
        warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
        sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
        af_cfg = AuxFreeConfig(
            rate=rate,
            momentum=momentum,
            bias_cap=bias_cap,
            warmup_steps=warmup,
            sync_group=sync_group,
        )

        # Discover layers to count the number and experts for state sizing
        adapters: list[BaseMoEAdapter] = [
            MixtralAdapter(),
            Qwen3Adapter(),
            Qwen35MoeAdapter(),
            BailingAdapter(),
            Llama4Adapter(),
        ]

        # For initial state sizing, we conservatively assume the first discovered layer defines nE
        n_layers = 0
        n_experts = None
        for _m in model.modules():
            n_layers += 1  # upper bound — we will re-use bias slots sparsely
        device = next(model.parameters(), torch.tensor(0)).device
        if n_layers <= 0:
            n_layers = 1
        if n_experts is None:
            # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
            n_experts = 2
        state = AuxFreeState(
            num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
        )
        ep_size = getattr(cfg, "expert_parallel_size", None)
        ep_group = None
        if sync_group == "ep":
            if dist.is_available() and dist.is_initialized():
                ep_group = self._resolve_ep_group(cfg)
            else:
                LOG.info(
                    "AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
                )
        self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)

        # Discover and prepare layers (attach per-layer buffers)
        self._handles = discover_and_prepare_layers(model, adapters, self._shim)

        LOG.info(
            f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
        )

    def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
        if not dist.is_available() or not dist.is_initialized():
            LOG.warning(
                "AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
            )
            return None
        ep_size = getattr(cfg, "expert_parallel_size", None)
        if not ep_size or ep_size <= 1:
            LOG.warning(
                "AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
            )
            return None
        world = dist.get_world_size()
        if world % ep_size != 0:
            LOG.warning(
                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
                ep_size,
                world,
            )
            return None
        if ep_size == world:
            return dist.group.WORLD

        rank = dist.get_rank()
        # All ranks must collectively create all EP subgroups in the same order
        # to avoid deadlocks (dist.new_group is a collective operation).
        world_size = world
        my_group = None
        for group_start in range(0, world_size, ep_size):
            ranks = tuple(range(group_start, group_start + ep_size))
            if ranks not in self._ep_group_cache:
                self._ep_group_cache[ranks] = dist.new_group(ranks)
            if rank in ranks:
                my_group = self._ep_group_cache[ranks]
        return my_group

    def add_callbacks_post_trainer(self, cfg, trainer):
        if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
            return []
        if self._shim is None:
            return []
        # gather concrete layer modules from handles
        layers = [h.layer for h in self._handles]
        cb = MoeAuxFreeBiasUpdateCallback(
            self._shim,
            layers,
            trainer,
        )
        LOG.info("AuxFreeMoE: registering post-step bias update callback")
        return [cb]
```
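The loop in `_resolve_ep_group` above walks every contiguous rank block so that all ranks issue the same sequence of collective `dist.new_group` calls. The partitioning it produces is easy to sanity-check in isolation; a standalone sketch, no distributed setup needed:

```python
def ep_partitions(world_size: int, ep_size: int) -> list[tuple[int, ...]]:
    """Contiguous rank blocks, e.g. world=8, ep=4 -> [(0,1,2,3), (4,5,6,7)]."""
    assert world_size % ep_size == 0
    return [
        tuple(range(start, start + ep_size))
        for start in range(0, world_size, ep_size)
    ]

assert ep_partitions(8, 4) == [(0, 1, 2, 3), (4, 5, 6, 7)]
assert ep_partitions(8, 2) == [(0, 1), (2, 3), (4, 5), (6, 7)]
```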
```diff
@@ -240,7 +240,16 @@ def _softmax_topk_route(
 
     top_k = base_gate.top_k
     num_experts = base_gate.num_experts
-    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    # Aux-free bias: biased selection, unbiased weights
+    afb_bias = getattr(moe_block, "_afb_bias", None)
+    if afb_bias is not None:
+        scores_for_choice = routing_weights + afb_bias
+        _, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
+        routing_weights = routing_weights.gather(1, selected_experts)
+        _accumulate_afb_counts(moe_block, selected_experts)
+    else:
+        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
 
     if getattr(base_gate, "norm_topk_prob", True):
         routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
     else:
         scores_for_choice = router_probs
 
+    # Aux-free bias: stacks on top of e_score_correction_bias for selection
+    afb_bias = getattr(moe_block, "_afb_bias", None)
+    if afb_bias is not None:
+        scores_for_choice = scores_for_choice + afb_bias
+
     # Group-based selection: pick top groups, mask the rest
     n_group = getattr(moe_block, "n_group", 1)
     if n_group > 1:
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
     # Gather weights from original sigmoid scores (not bias-corrected)
     topk_weights = router_probs.gather(1, topk_indices)
 
+    # Accumulate counts for aux-free bias update
+    if afb_bias is not None:
+        _accumulate_afb_counts(moe_block, topk_indices)
+
     # Optional renormalization + scaling
     if getattr(moe_block, "norm_topk_prob", True):
         topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
     )
 
 
+def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
+    """Accumulate per-expert token counts for aux-free bias updates."""
+    afb_counts = getattr(moe_block, "_afb_counts", None)
+    if afb_counts is None:
+        return
+    flat_idx = topk_indices.reshape(-1)
+    counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
+    afb_counts.add_(counts.to(afb_counts.dtype))
+
+
 # =============================================================================
 # Shared expert helpers
 # =============================================================================
```
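The `_accumulate_afb_counts` helper added above reduces a `[T, K]` top-k index tensor to per-expert totals; successive microbatches keep adding into the buffer until the post-step callback reduces and zeroes it. A quick illustration with hypothetical shapes:

```python
import torch

num_experts = 4
afb_counts = torch.zeros(num_experts)

# Two microbatches' worth of top-k indices, each of shape [T, K]
for topk_indices in (
    torch.tensor([[0, 1], [1, 2]]),
    torch.tensor([[3, 1], [0, 0]]),
):
    counts = torch.bincount(topk_indices.reshape(-1), minlength=num_experts)
    afb_counts.add_(counts.to(afb_counts.dtype))

print(afb_counts)  # tensor([3., 3., 1., 1.])
```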
```diff
@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
 
 Each model type maps to a (routing_fn, activation_type, router_attr) triple.
 When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
+
+Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
+active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
+as buffers. The routing functions transparently inject the bias into expert
+*selection* (biased topk) while keeping mixture *weights* from unbiased
+scores, then accumulate per-expert token counts for the post-step bias update.
 """
 
 import torch
@@ -101,17 +107,25 @@ def softmax_topk_routing(
     router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
     router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]
 
+    # Aux-free bias: biased selection, unbiased weights
+    afb_bias = getattr(moe_block, "_afb_bias", None)
+    scores_for_choice = router_probs
+    if afb_bias is not None:
+        scores_for_choice = router_probs + afb_bias
+
     # Select top-k experts per token
-    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each
+    top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1)  # [T, K] each
+
+    # When aux-free bias is active, gather unbiased weights and accumulate counts
+    if afb_bias is not None:
+        top_values = router_probs.gather(1, top_indices)
+        _accumulate_afb_counts(moe_block, top_indices)
 
     # Renormalize if configured (default True for models without the attribute,
     # e.g. Mixtral/MiniMax which always normalize)
     if getattr(gate, "norm_topk_prob", True):
         top_values = top_values / top_values.sum(dim=-1, keepdim=True)
 
     # no-op: matches transformers which casts to softmax output dtype (float32).
     # top_values = top_values.to(router_probs.dtype)
 
     # Flatten for moe_general_routing_inputs.
     # Token indices are naturally sorted ascending from the [T, K] layout:
     # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
     router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
     router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]
 
+    # Aux-free bias: inject before group selection / topk
+    afb_bias = getattr(moe_block, "_afb_bias", None)
+    scores_for_choice = router_probs
+    if afb_bias is not None:
+        scores_for_choice = router_probs + afb_bias
+
     # Group selection: pick top groups, mask the rest
     if n_group > 1:
@@ -159,11 +177,17 @@ def softmax_group_topk_routing(
     score_mask = (
         group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
     )
-    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+    scores_for_choice = scores_for_choice.masked_fill(
+        ~score_mask.bool(), -float("inf")
+    )
 
     topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
     topk_weights = router_probs.gather(1, topk_indices)
 
+    # Accumulate counts for aux-free bias update
+    if afb_bias is not None:
+        _accumulate_afb_counts(moe_block, topk_indices)
+
     # Renormalization + scaling
     norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
     if norm_topk_prob:
@@ -233,6 +257,11 @@ def sigmoid_topk_routing(
     )
     scores_for_choice = router_probs + e_score_correction_bias
 
+    # Aux-free bias: stacks on top of e_score_correction_bias for selection
+    afb_bias = getattr(moe_block, "_afb_bias", None)
+    if afb_bias is not None:
+        scores_for_choice = scores_for_choice + afb_bias
+
     # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
     if n_group > 1:
         group_scores = (
@@ -248,7 +277,9 @@ def sigmoid_topk_routing(
     score_mask = (
         group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
     )
-    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+    scores_for_choice = scores_for_choice.masked_fill(
+        ~score_mask.bool(), -float("inf")
+    )
 
     # Final topk from (possibly masked) scores
     topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
```
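The masked_fill change in the two hunks above matters once the AFB bias is in play: sigmoid scores plus a possibly negative bias can dip below zero, so masking pruned groups to 0.0 no longer guarantees they lose the top-k; -inf does. A small demonstration with made-up scores:

```python
import torch

scores = torch.tensor([[-0.2, 0.1, 0.3, 0.05]])   # biased scores can go negative
mask = torch.tensor([[True, True, False, False]])  # keep only the first group

zeroed = scores.masked_fill(~mask, 0.0)
neginf = scores.masked_fill(~mask, -float("inf"))

# With 0.0 masking, a masked expert (score 0.0) outranks the in-group -0.2,
# so a pruned expert leaks into the top-k:
print(torch.topk(zeroed, k=2).indices)  # e.g. tensor([[1, 2]])
# With -inf masking, the group mask is always respected:
print(torch.topk(neginf, k=2).indices)  # tensor([[1, 0]])
```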
```diff
@@ -256,6 +287,10 @@ def sigmoid_topk_routing(
     # Gather weights from original sigmoid scores (not bias-corrected)
     topk_weights = router_probs.gather(1, topk_indices)
 
+    # Accumulate counts for aux-free bias update
+    if afb_bias is not None:
+        _accumulate_afb_counts(moe_block, topk_indices)
+
     # Optional renormalization + scaling
     norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
     if norm_topk_prob:
@@ -276,3 +311,21 @@ def sigmoid_topk_routing(
     flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]
 
     return flat_scores, flat_token_idx, flat_expert_idx, router_logits
+
+
+def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
+    """Accumulate per-expert token counts for the aux-free bias update.
+
+    Called when ``moe_block._afb_bias`` is present (registered by the
+    ``aux_free_router`` plugin). The counts are later consumed by the
+    ``MoeAuxFreeBiasUpdateCallback`` at each training step.
+    """
+    if hasattr(moe_block, "training") and not moe_block.training:
+        return
+    afb_counts = getattr(moe_block, "_afb_counts", None)
+    if afb_counts is None:
+        return
+    num_experts = afb_counts.numel()
+    flat_idx = topk_indices.reshape(-1)
+    counts = torch.bincount(flat_idx, minlength=num_experts)
+    afb_counts.add_(counts.to(afb_counts.dtype))
```
```diff
@@ -299,6 +299,7 @@ def validate_config(
     AxolotlInputConfig = AxolotlInputConfigBase
 
     if cfg.plugins:
+        prepare_plugins(cfg)
         (
             AxolotlConfigWCapabilities,
             AxolotlInputConfig,
```
```diff
@@ -836,6 +836,12 @@ class AxolotlInputConfig(
             "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
         },
     )
+    expert_parallel_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
+        },
+    )
     special_tokens: SpecialTokensConfig | None = Field(
         default=None,
         json_schema_extra={
```
```diff
@@ -1386,6 +1386,14 @@ class ComplexValidationMixin:
             self.tensor_parallel_size = 1
         return self
 
+    @model_validator(mode="after")
+    def check_expert_parallel_size(self):
+        if not getattr(self, "expert_parallel_size", None):
+            self.expert_parallel_size = 1
+        elif self.expert_parallel_size < 1:
+            raise ValueError("expert_parallel_size must be >= 1")
+        return self
+
     @model_validator(mode="after")
     def check_context_parallel_size(self):
         if self.sequence_parallel_degree and not self.context_parallel_size:
```
tests/e2e/test_llama4_moe_aux_free.py (new file, +75 lines)
@@ -0,0 +1,75 @@
```python
"""
E2E smoke test for Llama 4 aux-loss-free routing via plugin
"""

import unittest

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir


class TestLlama4MoeAuxFree(unittest.TestCase):
    """Smoke test to ensure aux-free plugin patches Llama 4 MoE correctly."""

    @with_temp_dir
    def test_llama4_aux_free_smoke(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "yujiepan/llama-4-tiny-random",
                "tokenizer_config": "yujiepan/llama-4-tiny-random",
                "trust_remote_code": False,
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )

        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)

        patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
        assert patched is not None, (
            "Llama 4 MoE layer was not patched by aux-free plugin"
        )
        assert patched._afb_bias.ndim == 1
        assert patched._afb_counts.ndim == 1
        check_model_output_exists(temp_dir, cfg)


if __name__ == "__main__":
    unittest.main()
```
tests/e2e/test_moe_aux_free.py (new file, +79 lines)
@@ -0,0 +1,79 @@
```python
"""
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
"""

import unittest

import torch

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir


class TestMoeAuxFree(unittest.TestCase):
    """Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""

    @with_temp_dir
    def test_mixtral_aux_free_smoke(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and toggles
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )

        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)

        # Inspect model modules for a patched MoE layer
        patched = None
        for m in model.modules():
            if hasattr(m, "_afb_patched") and m._afb_patched is True:
                patched = m
                break
        assert patched is not None, "No MoE layer patched by aux-free plugin"
        assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
        assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
        # ensure counts buffer got reset by callback (best effort)
        assert torch.all(patched._afb_counts == 0)

        check_model_output_exists(temp_dir, cfg)
```
tests/e2e/test_moe_aux_parity.py (new file, +91 lines)
@@ -0,0 +1,91 @@
```python
"""
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""

import gc
import unittest

import torch

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir


def _last_logged_loss(trainer) -> float | None:
    # Scan log_history for the most recent entry with a 'loss' key
    for entry in reversed(trainer.state.log_history):
        if isinstance(entry, dict) and "loss" in entry:
            return float(entry["loss"])
    return None


class TestMoeAuxParity(unittest.TestCase):
    @with_temp_dir
    def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
        base_cfg = {
            "base_model": "hf-internal-testing/Mixtral-tiny",
            "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
            "flash_attention": False,
            "sequence_len": 512,
            "bf16": False,
            "fp16": False,
            "val_set_size": 0.02,
            "special_tokens": {},
            "datasets": [
                {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
            ],
            "num_epochs": 1,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "learning_rate": 1e-5,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
            "max_steps": 8,
            "save_steps": 0,
            "eval_steps": 0,
            "save_first_step": False,
            "seed": 42,
            "logging_steps": 1,
        }

        # Baseline: aux-loss (gshard)
        cfg0 = DictDefault(dict(base_cfg))
        cfg0.output_dir = f"{temp_dir}/baseline"
        cfg0 = validate_config(cfg0)
        normalize_config(cfg0)
        # baseline uses default aux-loss routing; no plugin registration
        dataset_meta0 = load_datasets(cfg=cfg0)
        model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
        loss0 = _last_logged_loss(trainer0)
        assert loss0 is not None

        # Release baseline resources before starting aux-free run
        del model0, trainer0, dataset_meta0
        gc.collect()
        torch.cuda.empty_cache()

        # Aux-free: plugin + noaux_tc
        cfg1 = DictDefault(dict(base_cfg))
        cfg1.output_dir = f"{temp_dir}/auxfree"
        cfg1.plugins = [
            "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
        ]
        cfg1.moe_balance_type = "noaux_tc"
        cfg1.moe_update_rate = 0.01
        cfg1.moe_update_momentum = 0.9
        cfg1.moe_bias_cap = 2.0
        prepare_plugins(cfg1)
        cfg1 = validate_config(cfg1)
        normalize_config(cfg1)
        dataset_meta1 = load_datasets(cfg=cfg1)
        model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
        loss1 = _last_logged_loss(trainer1)
        assert loss1 is not None

        # Assert aux-free loss is within 10% of aux-loss baseline
        assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"
```
tests/e2e/test_qwen3_moe_aux_free.py (new file, +76 lines)
@@ -0,0 +1,76 @@
"""
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
"""

import unittest

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir


class TestQwen3MoeAuxFree(unittest.TestCase):
    @with_temp_dir
    def test_qwen3_moe_aux_free_smoke(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
                "tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and toggles
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )

        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)

        # check that at least one sparse MoE block has been patched
        found = False
        for m in model.modules():
            if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
                m, "_afb_patched"
            ):
                assert m._afb_patched is True
                assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
                assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
                found = True
                break
        assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"

        check_model_output_exists(temp_dir, cfg)
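A note on the `_afb_bias` and `_afb_counts` buffers asserted above: the bias skews only which experts win top-k, while the mixture weights still come from the unbiased router logits. A minimal sketch of that split, assuming softmax routing with renormalized top-k weights (`biased_topk` is an illustrative name, not the plugin API):

    import torch

    def biased_topk(logits, afb_bias, top_k=2):
        _, idx = torch.topk(logits + afb_bias, top_k, dim=-1)    # selection sees the bias
        probs = torch.softmax(logits, dim=-1)                    # weights do not
        weights = probs.gather(-1, idx)
        return weights / weights.sum(dim=-1, keepdim=True), idx  # renormalize over top-k

    weights, idx = biased_topk(torch.randn(4, 8), torch.zeros(8))  # 4 tokens, 8 experts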
74
tests/e2e/test_ring_moe_aux_free.py
Normal file
@@ -0,0 +1,74 @@
"""
E2E smoke test for Ring 2.0 aux-loss-free routing via plugin
"""

import unittest

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir


class TestRingMoeAuxFree(unittest.TestCase):
    """Smoke test to ensure aux-free plugin patches Ring Mini 2.0 correctly."""

    @with_temp_dir
    def test_ring_aux_free_smoke(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "yujiepan/ring-tiny-random",
                "tokenizer_config": "yujiepan/ring-tiny-random",
                "trust_remote_code": True,
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin config
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )

        prepare_plugins(cfg)
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)

        patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
        assert patched is not None, "Ring MoE layer was not patched by aux-free plugin"
        assert patched._afb_bias.ndim == 1
        assert patched._afb_counts.ndim == 1
        check_model_output_exists(temp_dir, cfg)


if __name__ == "__main__":
    unittest.main()
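Both smoke tests run single-process, but in distributed runs the per-expert counts are summed across the configured sync group (`world` or `ep`) before each bias update, so every rank derives the same load estimate. A minimal sketch of that reduction, assuming a standard all-reduce (`reduce_expert_counts` is an illustrative name):

    import torch
    import torch.distributed as dist

    def reduce_expert_counts(counts, group=None):
        # No-op outside distributed runs; otherwise sum counts across the group.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
        return counts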
tests/e2e/utils.py
@@ -12,7 +12,11 @@ from pathlib import Path

 import torch
 from packaging import version
-from tbparse import SummaryReader
+
+try:
+    from tbparse import SummaryReader
+except ImportError:  # pragma: no cover - optional dependency
+    SummaryReader = None

 from axolotl.utils.dict import DictDefault

@@ -185,6 +189,10 @@ def check_tensorboard(
     """
     helper function to parse and check tensorboard logs
     """
+    if SummaryReader is None:
+        raise unittest.SkipTest(
+            "tbparse is not installed; skipping tensorboard assertions"
+        )
     tb_log_path = most_recent_subdir(temp_run_dir)
     event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
     reader = SummaryReader(event_file)
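The hunks above make `tbparse` an optional test dependency: missing installs now surface as a clean skip rather than an import error at collection time. The same guard pattern generalizes to any optional test dependency; a minimal sketch under that assumption (names illustrative):

    import unittest

    try:
        import tbparse  # optional dependency
    except ImportError:  # pragma: no cover
        tbparse = None

    def require_tbparse():
        # Raise a clean skip instead of failing at import time.
        if tbparse is None:
            raise unittest.SkipTest("tbparse is not installed")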
666
tests/unit/test_aux_free_adapters.py
Normal file
@@ -0,0 +1,666 @@
import os
import sys
import tempfile
import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.distributed as dist
import torch.nn as nn
from huggingface_hub import snapshot_download

from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin


def _cfg(**overrides):
    defaults = dict(
        moe_balance_type="noaux_tc",
        moe_update_rate=0.1,
        moe_update_momentum=0.9,
        moe_bias_cap=2.0,
        moe_afb_warmup_steps=0,
        moe_bias_sync_group="world",
        expert_parallel_size=1,
    )
    defaults.update(overrides)
    return SimpleNamespace(**defaults)


def _load_bailing_modules():
    repo_dir = snapshot_download(
        repo_id="inclusionAI/Ring-mini-2.0",
        allow_patterns=[
            "configuration_bailing_moe_v2.py",
            "modeling_bailing_moe_v2.py",
            "__init__.py",
        ],
    )
    repo = Path(repo_dir)
    config_path = repo / "configuration_bailing_moe_v2.py"
    modeling_path = repo / "modeling_bailing_moe_v2.py"

    config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
    if config_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(config_name, config_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[config_name] = module
        sys.modules["configuration_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    config_module = sys.modules[config_name]

    modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
    if modeling_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[modeling_name] = module
        sys.modules["modeling_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    modeling_module = sys.modules[modeling_name]

    BailingMoeV2Config = config_module.BailingMoeV2Config
    BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock

    return BailingMoeV2Config, BailingMoeV2SparseMoeBlock


def _build_bailing_model():
    BailingConfig, BailingBlock = _load_bailing_modules()
    config = BailingConfig(
        hidden_size=16,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_experts=4,
        num_shared_experts=None,
        num_experts_per_tok=2,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
    )
    block = BailingBlock(config)

    class DummyModel(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.block = layer
            self.config = SimpleNamespace(model_type="bailing_moe")

        def forward(self, hidden_states):
            return self.block(hidden_states)

    return DummyModel(block), block
def _build_llama4_model():
    import transformers
    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

    # Build config without __post_init__ validation (works around a
    # huggingface_hub strict-dataclass type mismatch for layer_types).
    config = object.__new__(transformers.Llama4TextConfig)
    config.__dict__.update(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_experts_per_tok=2,
        num_hidden_layers=2,
        hidden_act="silu",
        layer_types=None,
    )
    layer = Llama4TextMoe(config)

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="llama4")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _build_mixtral_model():
    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    config = MixtralConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    layer = MixtralSparseMoeBlock(config)
    layer.config = config

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="mixtral")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer
def _build_qwen35_moe_model():
    from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
        Qwen3_5MoeTextConfig,
    )
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
        Qwen3_5MoeSparseMoeBlock,
    )

    config = Qwen3_5MoeTextConfig(
        hidden_size=16,
        moe_intermediate_size=32,
        shared_expert_intermediate_size=32,
        num_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_hidden_layers=2,
    )
    layer = Qwen3_5MoeSparseMoeBlock(config)

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="qwen3_5_moe")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
    if args is None:
        args = SimpleNamespace(logging_steps=1)
    if state is None:
        state = SimpleNamespace(global_step=1, log_history=[])
    if control is None:
        control = SimpleNamespace(
            should_log=False,
            should_evaluate=False,
            should_save=False,
            should_training_stop=False,
        )

    class DummyTrainer:
        def __init__(self, state_obj, control_obj):
            self.state = state_obj
            self.control = control_obj

        def log(self, logs):
            output = dict(logs)
            output["step"] = self.state.global_step
            self.state.log_history.append(output)
            self.control.should_log = True

    dummy_trainer = DummyTrainer(state, control)
    callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
    assert callbacks, "expected aux-free callback to be registered"
    callback = callbacks[0]
    callback.on_step_end(args=args, state=state, control=control)
    return state, control
class TestAuxFreeAdapters(unittest.TestCase):
    def test_bailing_adapter_updates_counts_and_bias(self):
        model, block = _build_bailing_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(block, "_afb_bias"))
        hidden = torch.randn(2, 3, block.config.hidden_size)
        block(hidden)
        self.assertGreater(torch.count_nonzero(block._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
        self.assertFalse(
            torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
        )

    def test_llama4_adapter_biases_router_selection(self):
        model, layer = _build_llama4_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(layer, "_afb_bias"))
        hidden = torch.randn(2, 4, layer.hidden_dim)
        layer(hidden)
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_bias_warmup_respected(self):
        model, block = _build_bailing_model()
        cfg = _cfg(moe_afb_warmup_steps=2)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        def _step():
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            _run_callback(plugin, cfg)

        # Warmup steps should leave bias untouched.
        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )

        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )

        # Third step exceeds warmup -> bias should update.
        _step()
        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)

    def test_mixtral_adapter_patches_router_not_forward(self):
        """Verify that aux-free patches the router (gate) only, and that the
        v5 block forward signature (single tensor return) is preserved."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Gate should be patched, not the block forward
        self.assertTrue(getattr(layer.gate, "_afb_patched", False))
        self.assertTrue(getattr(layer, "_afb_patched", False))

        # v5 block forward returns a single tensor (not a tuple with logits)
        hidden = torch.randn(2, 3, layer.config.hidden_size)
        out = layer(hidden)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape, hidden.shape)

        # Counts should have been accumulated
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)

    def test_mixtral_adapter_bias_affects_selection(self):
        """When bias is large for one expert, it should be selected more often."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Set a large bias for expert 0 to force its selection
        layer._afb_bias.zero_()
        layer._afb_bias[0] = 10.0

        hidden = torch.randn(2, 8, layer.config.hidden_size)
        num_tokens = 2 * 8  # batch * seq
        layer(hidden)

        # With top_k=2, a heavily biased expert 0 lands in every token's
        # top-k exactly once, so it accumulates num_tokens counts
        # (not num_tokens * top_k).
        counts = layer._afb_counts.clone()
        self.assertEqual(
            int(counts[0].item()),
            num_tokens,
            msg="Expert 0 should be selected for every token when heavily biased",
        )
    def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
        """Verify Qwen 3.5 MoE: the router is patched, the shared expert is
        untouched, and the output includes the shared expert contribution."""
        model, layer = _build_qwen35_moe_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Gate should be patched
        self.assertTrue(getattr(layer.gate, "_afb_patched", False))
        self.assertTrue(getattr(layer, "_afb_patched", False))
        # Shared expert should be unmodified
        self.assertTrue(hasattr(layer, "shared_expert"))
        self.assertTrue(hasattr(layer, "shared_expert_gate"))

        # Forward should return a single tensor (shared + routed)
        hidden_size = layer.gate.hidden_dim
        hidden = torch.randn(2, 3, hidden_size)
        out = layer(hidden)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape, hidden.shape)

        # Counts should have been accumulated
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

    def test_qwen35_moe_adapter_bias_updates(self):
        """Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
        model, layer = _build_qwen35_moe_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        hidden_size = layer.gate.hidden_dim
        hidden = torch.randn(2, 4, hidden_size)
        layer(hidden)

        # Bias should start at zero
        self.assertTrue(
            torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
        )

        _run_callback(plugin, cfg)

        # After callback: counts reset, EMA updated, bias updated
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_qwen35_moe_adapter_model_type_matching(self):
        """Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
        from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter

        adapter = Qwen35MoeAdapter()

        model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
        model_text = SimpleNamespace(
            config=SimpleNamespace(model_type="qwen3_5_moe_text")
        )
        model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))

        self.assertTrue(adapter.matches(model_moe))
        self.assertTrue(adapter.matches(model_text))
        self.assertFalse(adapter.matches(model_other))

    def test_ep_group_resolution_deferred_until_dist_ready(self):
        if dist.is_available() and dist.is_initialized():
            self.skipTest(
                "Cannot safely test deferred EP group resolution when a process group is already initialized"
            )

        model, block = _build_bailing_model()
        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertIsNotNone(plugin._shim)
        self.assertIsNone(plugin._shim.ep_group)

        tmp_init = tempfile.NamedTemporaryFile(delete=False)
        tmp_init.close()
        init_method = f"file://{tmp_init.name}"
        dist.init_process_group(
            backend="gloo", init_method=init_method, world_size=1, rank=0
        )
        try:
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            _run_callback(
                plugin,
                cfg,
                args=SimpleNamespace(logging_steps=1),
                state=SimpleNamespace(global_step=1, log_history=[]),
                control=SimpleNamespace(
                    should_log=False,
                    should_evaluate=False,
                    should_save=False,
                    should_training_stop=False,
                ),
            )
            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
        finally:
            dist.destroy_process_group()
            os.unlink(tmp_init.name)

    def test_telemetry_logging(self):
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        hidden_dim = layer.config.hidden_size
        hidden = torch.randn(2, 3, hidden_dim)
        layer(hidden)

        args = SimpleNamespace(logging_steps=1)
        state = SimpleNamespace(global_step=1, log_history=[])
        control = SimpleNamespace(
            should_log=False,
            should_evaluate=False,
            should_save=False,
            should_training_stop=False,
        )
        _run_callback(plugin, cfg, args=args, state=state, control=control)

        self.assertTrue(control.should_log)
        self.assertTrue(state.log_history)
        telemetry = state.log_history[-1]
        self.assertEqual(telemetry["step"], state.global_step)
        self.assertIn("moe_afb/l0_load_min", telemetry)
        self.assertIn("moe_afb/l0_load_max", telemetry)
        self.assertIn("moe_afb/l0_bias_abs_max", telemetry)

    def test_get_num_experts_v5_attribute_paths(self):
        """Verify get_num_experts works with the v5 attribute layout, where
        num_experts lives on the gate/experts sub-modules, not the block."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        adapter = MixtralAdapter()

        # Simulates a v5 MixtralSparseMoeBlock (num_experts on gate, not block)
        block = SimpleNamespace(
            gate=SimpleNamespace(num_experts=8),
            experts=SimpleNamespace(num_experts=8),
        )
        self.assertEqual(adapter.get_num_experts(block), 8)

        # Also works when num_experts is directly on the block
        block2 = SimpleNamespace(num_experts=4)
        self.assertEqual(adapter.get_num_experts(block2), 4)
class TestAuxFreeKernelComposition(unittest.TestCase):
    """Tests that aux-free bias composes correctly with kernel routing."""

    def test_sonicmoe_softmax_routing_with_afb_bias(self):
        """SonicMoE softmax routing should use biased selection / unbiased weights."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts = 4
        top_k = 2
        hidden_dim = 16
        T = 6

        # Build a mock MoE block with gate attributes
        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True

        moe_block = SimpleNamespace(gate=gate)
        hidden = torch.randn(T, hidden_dim)

        # Baseline: no bias
        scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
            hidden, moe_block
        )
        self.assertEqual(scores_base.shape[0], T * top_k)

        # Now register aux-free buffers and set a heavy bias on expert 0
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)

        scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
            hidden, moe_block
        )

        # Expert 0 should now appear in the selections
        self.assertTrue(
            (exp_biased == 0).any(),
            "Expert 0 should appear in selections when heavily biased",
        )
        # Counts should have been accumulated
        self.assertGreater(moe_block._afb_counts[0].item(), 0)
        # Total counts should equal T * top_k
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_sonicmoe_routing_without_bias_unchanged(self):
        """Without _afb_bias, routing should produce identical results."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts = 4
        top_k = 2
        hidden_dim = 16

        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True

        moe_block = SimpleNamespace(gate=gate)
        hidden = torch.randn(4, hidden_dim)

        # Without the _afb_bias attribute
        scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)

        # With _afb_bias = zeros (should be equivalent)
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_counts = torch.zeros(num_experts)
        scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)

        torch.testing.assert_close(scores1, scores2)
        torch.testing.assert_close(exp1, exp2)
    @unittest.skipUnless(
        importlib_util.find_spec("triton") is not None,
        "triton not installed (required by scattermoe)",
    )
    def test_scattermoe_softmax_routing_with_afb_bias(self):
        """ScatterMoE softmax routing should use biased selection / unbiased weights."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )

        num_experts = 4
        top_k = 2
        hidden_dim = 16
        T = 6

        gate_weight = torch.randn(num_experts, hidden_dim)
        base_gate = SimpleNamespace(
            top_k=top_k,
            num_experts=num_experts,
            norm_topk_prob=True,
            weight=gate_weight,
        )

        moe_block = SimpleNamespace()
        hidden = torch.randn(T, hidden_dim)

        # Baseline without bias
        w_base, e_base, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )

        # With a heavy bias on expert 0
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)

        w_biased, e_biased, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )

        # Expert 0 should appear in the selections
        self.assertTrue((e_biased == 0).any())
        # Counts accumulated
        self.assertGreater(moe_block._afb_counts[0].item(), 0)
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_kernel_routing_skips_router_patch(self):
        """When a kernel backend has patched the block class, the adapter
        should skip patching the router (buffers are still registered)."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        adapter = MixtralAdapter()

        # Create a mock layer whose class has _original_forward (SonicMoE marker)
        class PatchedBlock(nn.Module):
            _original_forward = True  # SonicMoE marker

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)  # placeholder

        layer = PatchedBlock()
        self.assertTrue(adapter.uses_kernel_routing(layer))

        # Gate should NOT be patched (the kernel handles routing)
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))

    def test_adapter_buffers_registered_even_with_kernel(self):
        """Even when kernel routing is active, aux-free buffers must be
        registered on the MoE block so the kernel routing can find them."""
        from axolotl.integrations.aux_free_router.adapters import (
            LayerHandle,
            MixtralAdapter,
        )
        from axolotl.integrations.aux_free_router.core import (
            AuxFreeConfig,
            AuxFreeShim,
            AuxFreeState,
        )

        class PatchedBlock(nn.Module):
            _original_forward = True

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)

        layer = PatchedBlock()
        adapter = MixtralAdapter()
        cfg = AuxFreeConfig()
        state = AuxFreeState(
            num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
        )
        shim = AuxFreeShim(state=state)
        handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)

        adapter.prepare(layer, handle, shim)

        # Buffers should be registered for the kernel routing to use
        self.assertTrue(hasattr(layer, "_afb_bias"))
        self.assertTrue(hasattr(layer, "_afb_counts"))
        self.assertTrue(hasattr(layer, "_afb_ema"))
        # But the gate should NOT be patched
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))


if __name__ == "__main__":
    unittest.main()
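For readers tracing `test_telemetry_logging`: a minimal sketch of how the asserted keys could be derived from the per-layer buffers, assuming loads are normalized token frequencies and sign flips are measured against the previous step's bias sign (the plugin's internals may differ; `layer_telemetry` is an illustrative name):

    import torch

    def layer_telemetry(idx, counts, bias, prev_sign):
        load = counts / counts.sum().clamp(min=1)  # normalized token frequency
        sign = torch.sign(bias)
        flip_frac = (sign != prev_sign).float().mean().item()
        return {
            f"moe_afb/l{idx}_load_min": load.min().item(),
            f"moe_afb/l{idx}_load_mean": load.mean().item(),
            f"moe_afb/l{idx}_load_max": load.max().item(),
            f"moe_afb/l{idx}_bias_abs_max": bias.abs().max().item(),
            f"moe_afb/l{idx}_bias_sign_flip_frac": flip_frac,
        }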