Compare commits
4 Commits
textui
...
949cdf01eb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
949cdf01eb | ||
|
|
a0019021dd | ||
|
|
2af7475fdf | ||
|
|
3e4688289c |
40
src/axolotl/integrations/aux_free_router/README.md
Normal file
40
src/axolotl/integrations/aux_free_router/README.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Aux-Loss-Free MoE Router Plugin
|
||||||
|
|
||||||
|
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
|
||||||
|
|
||||||
|
Summary
|
||||||
|
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
|
||||||
|
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
|
||||||
|
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
|
||||||
|
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
|
||||||
|
|
||||||
|
Enable
|
||||||
|
- Add the plugin to your YAML, then set the aux-free toggle:
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
|
||||||
|
|
||||||
|
moe_balance_type: noaux_tc
|
||||||
|
moe_update_rate: 0.01 # default if unset
|
||||||
|
moe_update_momentum: 0.9 # default if unset
|
||||||
|
moe_bias_cap: 2.0 # default if unset
|
||||||
|
moe_afb_warmup_steps: 100 # optional
|
||||||
|
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
|
||||||
|
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
|
||||||
|
|
||||||
|
Config keys
|
||||||
|
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
|
||||||
|
- moe_update_rate: bias update rate (gamma). Default: 0.01.
|
||||||
|
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
|
||||||
|
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
|
||||||
|
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||||
|
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||||
|
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||||
|
|
||||||
|
Compatibility
|
||||||
|
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||||
|
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
||||||
|
- Telemetry: future updates will log per-expert loads and bias magnitudes.
|
||||||
2
src/axolotl/integrations/aux_free_router/__init__.py
Normal file
2
src/axolotl/integrations/aux_free_router/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
"""Aux-loss-free (AFB) MoE router integration package."""
|
||||||
|
|
||||||
317
src/axolotl/integrations/aux_free_router/adapters.py
Normal file
317
src/axolotl/integrations/aux_free_router/adapters.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .core import AuxFreeShim
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LayerHandle:
|
||||||
|
layer: nn.Module
|
||||||
|
layer_idx: int
|
||||||
|
num_experts: int
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMoEAdapter:
|
||||||
|
"""Base adapter that discovers MoE layers and wraps their forward.
|
||||||
|
|
||||||
|
Concrete adapters should implement discovery and per-layer attribute extraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
family: str = "generic"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
||||||
|
return False
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
||||||
|
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))
|
||||||
|
|
||||||
|
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
||||||
|
return int(getattr(moe_layer, "num_experts"))
|
||||||
|
|
||||||
|
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
||||||
|
# Best-effort: zero router aux loss coef if present
|
||||||
|
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
||||||
|
try:
|
||||||
|
setattr(model_or_layer, "router_aux_loss_coef", 0.0)
|
||||||
|
except Exception: # pragma: no cover - non-critical
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||||
|
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
||||||
|
if not hasattr(moe_layer, "_afb_bias"):
|
||||||
|
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
|
||||||
|
if not hasattr(moe_layer, "_afb_counts"):
|
||||||
|
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
|
||||||
|
if not hasattr(moe_layer, "_afb_ema"):
|
||||||
|
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
|
||||||
|
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||||
|
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||||
|
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
||||||
|
|
||||||
|
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||||
|
"""Attach per-layer buffers and mark as aux-free enabled."""
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
self._patch_forward_with_aux_free(moe_layer)
|
||||||
|
|
||||||
|
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
|
||||||
|
"""Replace the layer's forward with an aux-free gating version.
|
||||||
|
|
||||||
|
Assumes the layer exposes attributes:
|
||||||
|
- gate: linear router projecting hidden to num_experts
|
||||||
|
- num_experts: int
|
||||||
|
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
|
||||||
|
"""
|
||||||
|
if getattr(moe_layer, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
|
||||||
|
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
|
||||||
|
return
|
||||||
|
|
||||||
|
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
||||||
|
# hidden_states: (B, T, H)
|
||||||
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
|
hs = hidden_states.view(-1, hdim)
|
||||||
|
logits = self.gate(hs)
|
||||||
|
# selection uses biased logits; weights from unbiased logits
|
||||||
|
bias = getattr(self, "_afb_bias")
|
||||||
|
top_k = int(getattr(self, "_afb_top_k", 2))
|
||||||
|
biased = logits + bias # broadcast over tokens
|
||||||
|
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
|
||||||
|
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||||
|
weights = torch.softmax(chosen_logits.float(), dim=-1)
|
||||||
|
weights = weights.to(hs.dtype)
|
||||||
|
|
||||||
|
# accumulate counts for bias update callback
|
||||||
|
flat_idx = topk_idx.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
||||||
|
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
|
||||||
|
|
||||||
|
# dispatch tokens to experts
|
||||||
|
hs_rep = hs.repeat_interleave(top_k, dim=0)
|
||||||
|
y = torch.empty_like(hs_rep)
|
||||||
|
for eid in range(int(self.num_experts)):
|
||||||
|
mask = flat_idx == eid
|
||||||
|
if mask.any():
|
||||||
|
y[mask] = self.experts[eid](hs_rep[mask])
|
||||||
|
|
||||||
|
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
|
||||||
|
out = y.view(bsz, seqlen, hdim)
|
||||||
|
return (out, logits)
|
||||||
|
|
||||||
|
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
||||||
|
setattr(moe_layer, "_afb_patched", True)
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralAdapter(BaseMoEAdapter):
|
||||||
|
family = "mixtral"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||||
|
|
||||||
|
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
self._patch_mixtral_forward(moe_layer, shim)
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
||||||
|
yield m
|
||||||
|
|
||||||
|
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
|
||||||
|
if getattr(moe_layer, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
shim_ref = shim
|
||||||
|
|
||||||
|
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
||||||
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
|
if self.training and getattr(self, "jitter_noise", 0) > 0:
|
||||||
|
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
|
||||||
|
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
|
||||||
|
)
|
||||||
|
flat_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
router_logits = self.gate(flat_states)
|
||||||
|
|
||||||
|
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
|
||||||
|
top_k = int(getattr(self, "_afb_top_k", self.top_k))
|
||||||
|
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
|
||||||
|
routing_weights = routing_weights.to(flat_states.dtype)
|
||||||
|
|
||||||
|
flat_idx = selected_experts.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
||||||
|
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
|
||||||
|
|
||||||
|
final_hidden_states = torch.zeros(
|
||||||
|
(batch_size * sequence_length, hidden_dim),
|
||||||
|
dtype=flat_states.dtype,
|
||||||
|
device=flat_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||||
|
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||||
|
for expert_idx in expert_hit:
|
||||||
|
expert_layer = self.experts[expert_idx]
|
||||||
|
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
||||||
|
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
|
||||||
|
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||||
|
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
|
||||||
|
|
||||||
|
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||||
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
|
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
||||||
|
setattr(moe_layer, "_afb_patched", True)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3Adapter(MixtralAdapter):
|
||||||
|
family = "qwen3_moe"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
|
||||||
|
|
||||||
|
|
||||||
|
class BailingAdapter(BaseMoEAdapter):
|
||||||
|
family = "bailing_moe"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
model_type = getattr(getattr(model, "config", object()), "model_type", "")
|
||||||
|
return model_type in ("bailing_moe", "bailing_moe_v2")
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
|
||||||
|
yield m
|
||||||
|
|
||||||
|
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||||
|
if hasattr(moe_layer, "num_experts"):
|
||||||
|
return int(getattr(moe_layer, "num_experts"))
|
||||||
|
cfg = getattr(moe_layer, "config", None)
|
||||||
|
return int(getattr(cfg, "num_experts"))
|
||||||
|
|
||||||
|
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
self._patch_bailing_gate(moe_layer)
|
||||||
|
|
||||||
|
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
|
||||||
|
gate = getattr(moe_layer, "gate", None)
|
||||||
|
if gate is None:
|
||||||
|
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
|
||||||
|
return
|
||||||
|
if getattr(gate, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
def afb_gate_forward(self, hidden_states: torch.Tensor):
|
||||||
|
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
logits = F.linear(flat.float(), self.weight.float())
|
||||||
|
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||||
|
bias = getattr(moe_layer, "_afb_bias")
|
||||||
|
biased_scores = scores_unbiased + bias
|
||||||
|
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
||||||
|
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||||
|
if self.top_k > 1:
|
||||||
|
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
||||||
|
weights = weights / denom
|
||||||
|
weights = weights * self.routed_scaling_factor
|
||||||
|
|
||||||
|
flat_topk = topk_idx.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
||||||
|
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
||||||
|
|
||||||
|
return topk_idx, weights.to(hidden_states.dtype), logits
|
||||||
|
|
||||||
|
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||||
|
setattr(gate, "_afb_patched", True)
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4Adapter(BaseMoEAdapter):
|
||||||
|
family = "llama4"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__ == "Llama4TextMoe":
|
||||||
|
yield m
|
||||||
|
|
||||||
|
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
self._patch_llama4_router(moe_layer)
|
||||||
|
|
||||||
|
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
|
||||||
|
router = getattr(moe_layer, "router", None)
|
||||||
|
if router is None:
|
||||||
|
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
|
||||||
|
return
|
||||||
|
if getattr(router, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||||
|
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
router_logits = F.linear(flat, self.weight, self.bias)
|
||||||
|
bias = getattr(moe_layer, "_afb_bias")
|
||||||
|
biased_logits = router_logits + bias
|
||||||
|
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
||||||
|
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
||||||
|
router_scores = torch.full_like(router_logits, float("-inf"))
|
||||||
|
router_scores.scatter_(1, router_indices, unbiased_top)
|
||||||
|
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||||
|
|
||||||
|
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
||||||
|
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
||||||
|
|
||||||
|
return router_scores, router_logits
|
||||||
|
|
||||||
|
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
||||||
|
setattr(router, "_afb_patched", True)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
|
||||||
|
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||||
|
|
||||||
|
Returns a list of layer handles for later routing patching and updates.
|
||||||
|
"""
|
||||||
|
handles: list[LayerHandle] = []
|
||||||
|
adapter: Optional[BaseMoEAdapter] = None
|
||||||
|
for a in adapters:
|
||||||
|
if a.matches(model):
|
||||||
|
adapter = a
|
||||||
|
break
|
||||||
|
|
||||||
|
if adapter is None:
|
||||||
|
LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
|
||||||
|
return handles
|
||||||
|
|
||||||
|
# disable aux loss at model level if possible
|
||||||
|
adapter.disable_aux_loss(getattr(model, "config", model))
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for layer in adapter.find_moe_layers(model):
|
||||||
|
try:
|
||||||
|
top_k = adapter.get_top_k(layer)
|
||||||
|
nE = adapter.get_num_experts(layer)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
||||||
|
adapter.prepare(layer, handle, shim)
|
||||||
|
handles.append(handle)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
|
||||||
|
return handles
|
||||||
150
src/axolotl/integrations/aux_free_router/core.py
Normal file
150
src/axolotl/integrations/aux_free_router/core.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuxFreeConfig:
|
||||||
|
rate: float = 0.01
|
||||||
|
momentum: float = 0.9
|
||||||
|
bias_cap: float = 2.0
|
||||||
|
warmup_steps: int = 0
|
||||||
|
sync_group: str = "world" # or "ep"
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeState:
|
||||||
|
"""Holds per-layer bias and EMA load buffers."""
|
||||||
|
|
||||||
|
def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
|
||||||
|
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||||
|
self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||||
|
self.cfg = cfg
|
||||||
|
self.steps = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeShim:
|
||||||
|
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state: AuxFreeState,
|
||||||
|
ep_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
ep_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.state = state
|
||||||
|
self.ep_group = ep_group
|
||||||
|
self._ep_size = ep_size
|
||||||
|
self._ep_group_pending = (
|
||||||
|
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
||||||
|
)
|
||||||
|
self._layer_modules: dict[int, torch.nn.Module] = {}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
||||||
|
module = self._layer_modules.get(layer_idx)
|
||||||
|
if module is not None and hasattr(module, "_afb_bias"):
|
||||||
|
b = getattr(module, "_afb_bias")
|
||||||
|
else:
|
||||||
|
b = self.state.bias[layer_idx]
|
||||||
|
biased = logits + b # bias is a buffer
|
||||||
|
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
||||||
|
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||||
|
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
||||||
|
return topk_idx, weights
|
||||||
|
|
||||||
|
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
||||||
|
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
||||||
|
self._layer_modules[layer_idx] = module
|
||||||
|
bias = getattr(module, "_afb_bias")
|
||||||
|
ema = getattr(module, "_afb_ema")
|
||||||
|
# Keep state views pointing to the same tensors to avoid drift.
|
||||||
|
if layer_idx < len(self.state.bias):
|
||||||
|
self.state.bias[layer_idx] = bias
|
||||||
|
if layer_idx < len(self.state.ema_load):
|
||||||
|
self.state.ema_load[layer_idx] = ema
|
||||||
|
|
||||||
|
def begin_step(self) -> None:
|
||||||
|
"""Call once per optimizer step before per-layer updates."""
|
||||||
|
self.state.steps += 1
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
||||||
|
self._maybe_init_ep_group()
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
return counts
|
||||||
|
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
|
||||||
|
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
|
||||||
|
return counts
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
|
||||||
|
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
|
||||||
|
cfg = self.state.cfg
|
||||||
|
if self.state.steps <= cfg.warmup_steps:
|
||||||
|
return
|
||||||
|
|
||||||
|
nE = step_counts.numel()
|
||||||
|
if tokens_seen <= 0:
|
||||||
|
return
|
||||||
|
module = self._layer_modules.get(layer_idx)
|
||||||
|
if module is not None and hasattr(module, "_afb_ema"):
|
||||||
|
ema = getattr(module, "_afb_ema")
|
||||||
|
bias = getattr(module, "_afb_bias")
|
||||||
|
else:
|
||||||
|
ema = self.state.ema_load[layer_idx]
|
||||||
|
bias = self.state.bias[layer_idx]
|
||||||
|
counts = step_counts.to(ema.device)
|
||||||
|
freq = counts.float() / float(tokens_seen)
|
||||||
|
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
|
||||||
|
target = 1.0 / float(nE)
|
||||||
|
delta = cfg.rate * (target - ema)
|
||||||
|
# optional mean-centering to keep sum(bias) ~ 0
|
||||||
|
delta = delta - delta.mean()
|
||||||
|
bias.add_(delta)
|
||||||
|
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
||||||
|
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
||||||
|
|
||||||
|
def _maybe_init_ep_group(self) -> None:
|
||||||
|
if not self._ep_group_pending:
|
||||||
|
return
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
return
|
||||||
|
ep_size = self._ep_size
|
||||||
|
if not ep_size or ep_size <= 1:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
|
||||||
|
)
|
||||||
|
self.ep_group = dist.group.WORLD
|
||||||
|
self._ep_group_pending = False
|
||||||
|
return
|
||||||
|
world = dist.get_world_size()
|
||||||
|
if world % ep_size != 0:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
|
||||||
|
ep_size,
|
||||||
|
world,
|
||||||
|
)
|
||||||
|
self.ep_group = dist.group.WORLD
|
||||||
|
self._ep_group_pending = False
|
||||||
|
return
|
||||||
|
if ep_size == world:
|
||||||
|
self.ep_group = dist.group.WORLD
|
||||||
|
else:
|
||||||
|
rank = dist.get_rank()
|
||||||
|
group_start = (rank // ep_size) * ep_size
|
||||||
|
ranks = tuple(range(group_start, group_start + ep_size))
|
||||||
|
self.ep_group = dist.new_group(ranks)
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
|
||||||
|
ep_size,
|
||||||
|
dist.get_world_size(),
|
||||||
|
)
|
||||||
|
self._ep_group_pending = False
|
||||||
175
src/axolotl/integrations/aux_free_router/plugin.py
Normal file
175
src/axolotl/integrations/aux_free_router/plugin.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Aux-loss-free MoE Router Plugin for Axolotl.
|
||||||
|
|
||||||
|
This plugin wires an aux-free gating option into compatible MoE models using
|
||||||
|
unbiased logits for mixture weights and per-expert biases for top-k selection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .adapters import (
|
||||||
|
BailingAdapter,
|
||||||
|
BaseMoEAdapter,
|
||||||
|
Llama4Adapter,
|
||||||
|
MixtralAdapter,
|
||||||
|
Qwen3Adapter,
|
||||||
|
discover_and_prepare_layers,
|
||||||
|
)
|
||||||
|
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||||
|
"""Post-step callback to update aux-free biases from accumulated expert counts.
|
||||||
|
|
||||||
|
Note: The current revision expects per-layer counts to be accumulated on each
|
||||||
|
MoE layer as a buffer named `_afb_counts` during forward (to be added with
|
||||||
|
routing patches in a follow-up).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
|
||||||
|
self.shim = shim
|
||||||
|
self.layer_modules = layer_modules
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
||||||
|
# Iterate prepared MoE layers and apply the bias update rule.
|
||||||
|
self.shim.begin_step()
|
||||||
|
for layer in self.layer_modules:
|
||||||
|
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
|
||||||
|
continue
|
||||||
|
counts = getattr(layer, "_afb_counts")
|
||||||
|
if counts is None:
|
||||||
|
continue
|
||||||
|
counts = self.shim.all_reduce_counts(counts)
|
||||||
|
layer_idx = getattr(layer, "_afb_layer_idx", None)
|
||||||
|
if layer_idx is None:
|
||||||
|
counts.zero_()
|
||||||
|
continue
|
||||||
|
bias = getattr(layer, "_afb_bias")
|
||||||
|
counts_for_update = counts.to(bias.device)
|
||||||
|
tokens_seen = int(counts_for_update.sum().item())
|
||||||
|
# local layer-state EMA and bias update
|
||||||
|
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
||||||
|
# reset step counts
|
||||||
|
counts.zero_()
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeMoEPlugin(BasePlugin):
|
||||||
|
"""Plugin that enables aux-loss-free routing when configured."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._handles: list = []
|
||||||
|
self._shim: Optional[AuxFreeShim] = None
|
||||||
|
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
||||||
|
|
||||||
|
def post_model_build(self, cfg, model):
|
||||||
|
# Enable only when explicitly requested
|
||||||
|
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Be conservative — skip known native aux-free families
|
||||||
|
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||||
|
"deepseek_v3",
|
||||||
|
"glm4_moe",
|
||||||
|
)
|
||||||
|
if native_auxfree:
|
||||||
|
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build aux-free state and shim
|
||||||
|
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
|
||||||
|
momentum = (
|
||||||
|
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
|
||||||
|
)
|
||||||
|
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
||||||
|
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
||||||
|
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
||||||
|
af_cfg = AuxFreeConfig(
|
||||||
|
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
|
||||||
|
)
|
||||||
|
|
||||||
|
# Discover layers to count the number and experts for state sizing
|
||||||
|
adapters: list[BaseMoEAdapter] = [
|
||||||
|
MixtralAdapter(),
|
||||||
|
Qwen3Adapter(),
|
||||||
|
BailingAdapter(),
|
||||||
|
Llama4Adapter(),
|
||||||
|
]
|
||||||
|
|
||||||
|
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||||
|
n_layers = 0
|
||||||
|
n_experts = None
|
||||||
|
for m in model.modules():
|
||||||
|
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
||||||
|
device = next(model.parameters(), torch.tensor(0)).device
|
||||||
|
if n_layers <= 0:
|
||||||
|
n_layers = 1
|
||||||
|
if n_experts is None:
|
||||||
|
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||||
|
n_experts = 2
|
||||||
|
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
|
||||||
|
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||||
|
ep_group = None
|
||||||
|
if sync_group == "ep":
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
ep_group = self._resolve_ep_group(cfg)
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
|
||||||
|
)
|
||||||
|
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
|
||||||
|
|
||||||
|
# Discover and prepare layers (attach per-layer buffers)
|
||||||
|
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
|
||||||
|
return None
|
||||||
|
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||||
|
if not ep_size or ep_size <= 1:
|
||||||
|
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
|
||||||
|
return None
|
||||||
|
world = dist.get_world_size()
|
||||||
|
if world % ep_size != 0:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
|
||||||
|
ep_size,
|
||||||
|
world,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
if ep_size == world:
|
||||||
|
return dist.group.WORLD
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
|
group_start = (rank // ep_size) * ep_size
|
||||||
|
ranks = tuple(range(group_start, group_start + ep_size))
|
||||||
|
if ranks not in self._ep_group_cache:
|
||||||
|
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
||||||
|
return self._ep_group_cache[ranks]
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||||
|
return []
|
||||||
|
if self._shim is None:
|
||||||
|
return []
|
||||||
|
# gather concrete layer modules from handles
|
||||||
|
layers = [h.layer for h in self._handles]
|
||||||
|
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
|
||||||
|
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
||||||
|
return [cb]
|
||||||
@@ -758,6 +758,44 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
llama4_linearized_experts: bool | None = None
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
|
# MoE aux-loss-free (AFB) toggles
|
||||||
|
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_update_rate: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. If unset, plugin default is 0.01.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_update_momentum: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "EMA momentum for expert load smoothing (0–1). If unset, plugin default is 0.9.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_bias_cap: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_afb_warmup_steps: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = Field(
|
deepspeed: str | dict[str, Any] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -836,6 +874,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
expert_parallel_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
|
||||||
|
},
|
||||||
|
)
|
||||||
special_tokens: SpecialTokensConfig | None = Field(
|
special_tokens: SpecialTokensConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -1386,6 +1386,14 @@ class ComplexValidationMixin:
|
|||||||
self.tensor_parallel_size = 1
|
self.tensor_parallel_size = 1
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_expert_parallel_size(self):
|
||||||
|
if not getattr(self, "expert_parallel_size", None):
|
||||||
|
self.expert_parallel_size = 1
|
||||||
|
elif self.expert_parallel_size < 1:
|
||||||
|
raise ValueError("expert_parallel_size must be >= 1")
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_context_parallel_size(self):
|
def check_context_parallel_size(self):
|
||||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||||
|
|||||||
79
tests/e2e/test_moe_aux_free.py
Normal file
79
tests/e2e/test_moe_aux_free.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoeAuxFree(unittest.TestCase):
|
||||||
|
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_mixtral_aux_free_smoke(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||||
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
# Aux-free plugin and toggles
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
],
|
||||||
|
"moe_balance_type": "noaux_tc",
|
||||||
|
"moe_update_rate": 0.01,
|
||||||
|
"moe_update_momentum": 0.9,
|
||||||
|
"moe_bias_cap": 2.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
# Inspect model modules for a patched MoE layer
|
||||||
|
patched = None
|
||||||
|
for m in model.modules():
|
||||||
|
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
|
||||||
|
patched = m
|
||||||
|
break
|
||||||
|
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
||||||
|
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
|
||||||
|
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
|
||||||
|
# ensure counts buffer got reset by callback (best effort)
|
||||||
|
assert torch.all(patched._afb_counts == 0)
|
||||||
|
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
83
tests/e2e/test_moe_aux_parity.py
Normal file
83
tests/e2e/test_moe_aux_parity.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
|
||||||
|
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _last_logged_loss(trainer) -> float | None:
|
||||||
|
# Scan log_history for the most recent entry with a 'loss' key
|
||||||
|
for entry in reversed(trainer.state.log_history):
|
||||||
|
if isinstance(entry, dict) and "loss" in entry:
|
||||||
|
return float(entry["loss"])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoeAuxParity(unittest.TestCase):
|
||||||
|
@with_temp_dir
|
||||||
|
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
|
||||||
|
base_cfg = {
|
||||||
|
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||||
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 8,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
"seed": 42,
|
||||||
|
"logging_steps": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Baseline: aux-loss (gshard)
|
||||||
|
cfg0 = DictDefault(dict(base_cfg))
|
||||||
|
cfg0.output_dir = f"{temp_dir}/baseline"
|
||||||
|
cfg0 = validate_config(cfg0)
|
||||||
|
normalize_config(cfg0)
|
||||||
|
# baseline uses default aux-loss routing; no plugin registration
|
||||||
|
dataset_meta0 = load_datasets(cfg=cfg0)
|
||||||
|
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
|
||||||
|
loss0 = _last_logged_loss(trainer0)
|
||||||
|
assert loss0 is not None
|
||||||
|
|
||||||
|
# Aux-free: plugin + noaux_tc
|
||||||
|
cfg1 = DictDefault(dict(base_cfg))
|
||||||
|
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||||
|
cfg1.plugins = [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
]
|
||||||
|
cfg1.moe_balance_type = "noaux_tc"
|
||||||
|
cfg1.moe_update_rate = 0.01
|
||||||
|
cfg1.moe_update_momentum = 0.9
|
||||||
|
cfg1.moe_bias_cap = 2.0
|
||||||
|
cfg1 = validate_config(cfg1)
|
||||||
|
normalize_config(cfg1)
|
||||||
|
prepare_plugins(cfg1)
|
||||||
|
dataset_meta1 = load_datasets(cfg=cfg1)
|
||||||
|
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
|
||||||
|
loss1 = _last_logged_loss(trainer1)
|
||||||
|
assert loss1 is not None
|
||||||
|
|
||||||
|
# Assert aux-free loss is within 10% of aux-loss baseline
|
||||||
|
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"
|
||||||
76
tests/e2e/test_qwen3_moe_aux_free.py
Normal file
76
tests/e2e/test_qwen3_moe_aux_free.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3MoeAuxFree(unittest.TestCase):
|
||||||
|
@with_temp_dir
|
||||||
|
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||||
|
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
# Aux-free plugin and toggles
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
],
|
||||||
|
"moe_balance_type": "noaux_tc",
|
||||||
|
"moe_update_rate": 0.01,
|
||||||
|
"moe_update_momentum": 0.9,
|
||||||
|
"moe_bias_cap": 2.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
# check that at least one sparse MoE block has been patched
|
||||||
|
found = False
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
|
||||||
|
assert m._afb_patched is True
|
||||||
|
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
||||||
|
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
|
||||||
|
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -12,7 +12,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from tbparse import SummaryReader
|
try:
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
SummaryReader = None
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -185,6 +188,8 @@ def check_tensorboard(
|
|||||||
"""
|
"""
|
||||||
helper function to parse and check tensorboard logs
|
helper function to parse and check tensorboard logs
|
||||||
"""
|
"""
|
||||||
|
if SummaryReader is None:
|
||||||
|
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
|
||||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
|
|||||||
267
tests/unit/test_aux_free_adapters.py
Normal file
267
tests/unit/test_aux_free_adapters.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from importlib import util as importlib_util
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(**overrides):
|
||||||
|
defaults = dict(
|
||||||
|
moe_balance_type="noaux_tc",
|
||||||
|
moe_update_rate=0.1,
|
||||||
|
moe_update_momentum=0.9,
|
||||||
|
moe_bias_cap=2.0,
|
||||||
|
moe_afb_warmup_steps=0,
|
||||||
|
moe_bias_sync_group="world",
|
||||||
|
expert_parallel_size=1,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return SimpleNamespace(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_bailing_modules():
|
||||||
|
repo_dir = snapshot_download(
|
||||||
|
repo_id="inclusionAI/Ring-mini-2.0",
|
||||||
|
allow_patterns=[
|
||||||
|
"configuration_bailing_moe_v2.py",
|
||||||
|
"modeling_bailing_moe_v2.py",
|
||||||
|
"__init__.py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
repo = Path(repo_dir)
|
||||||
|
config_path = repo / "configuration_bailing_moe_v2.py"
|
||||||
|
modeling_path = repo / "modeling_bailing_moe_v2.py"
|
||||||
|
|
||||||
|
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
|
||||||
|
if config_name not in sys.modules:
|
||||||
|
spec = importlib_util.spec_from_file_location(config_name, config_path)
|
||||||
|
module = importlib_util.module_from_spec(spec)
|
||||||
|
sys.modules[config_name] = module
|
||||||
|
sys.modules["configuration_bailing_moe_v2"] = module
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
config_module = sys.modules[config_name]
|
||||||
|
|
||||||
|
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
|
||||||
|
if modeling_name not in sys.modules:
|
||||||
|
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
|
||||||
|
module = importlib_util.module_from_spec(spec)
|
||||||
|
sys.modules[modeling_name] = module
|
||||||
|
sys.modules["modeling_bailing_moe_v2"] = module
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
modeling_module = sys.modules[modeling_name]
|
||||||
|
|
||||||
|
BailingMoeV2Config = config_module.BailingMoeV2Config
|
||||||
|
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
|
||||||
|
|
||||||
|
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bailing_model():
|
||||||
|
BailingConfig, BailingBlock = _load_bailing_modules()
|
||||||
|
config = BailingConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
intermediate_size=32,
|
||||||
|
moe_intermediate_size=32,
|
||||||
|
num_experts=4,
|
||||||
|
num_shared_experts=None,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
n_group=1,
|
||||||
|
topk_group=1,
|
||||||
|
routed_scaling_factor=1.0,
|
||||||
|
)
|
||||||
|
block = BailingBlock(config)
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, layer):
|
||||||
|
super().__init__()
|
||||||
|
self.block = layer
|
||||||
|
self.config = SimpleNamespace(model_type="bailing_moe")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.block(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(block), block
|
||||||
|
|
||||||
|
|
||||||
|
def _build_llama4_model():
|
||||||
|
from transformers import Llama4TextConfig
|
||||||
|
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||||
|
|
||||||
|
config = Llama4TextConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
intermediate_size=32,
|
||||||
|
num_local_experts=4,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
)
|
||||||
|
layer = Llama4TextMoe(config)
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, moe_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.moe = moe_layer
|
||||||
|
self.config = SimpleNamespace(model_type="llama4")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.moe(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mixtral_model():
|
||||||
|
from transformers import MixtralConfig
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
|
config = MixtralConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
intermediate_size=32,
|
||||||
|
num_local_experts=4,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
)
|
||||||
|
layer = MixtralSparseMoeBlock(config)
|
||||||
|
layer.config = config
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, moe_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.moe = moe_layer
|
||||||
|
self.config = SimpleNamespace(model_type="mixtral")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.moe(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
|
def _run_callback(plugin, cfg):
|
||||||
|
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||||
|
assert callbacks, "expected aux-free callback to be registered"
|
||||||
|
callback = callbacks[0]
|
||||||
|
dummy = SimpleNamespace()
|
||||||
|
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxFreeAdapters(unittest.TestCase):
|
||||||
|
def test_bailing_adapter_updates_counts_and_bias(self):
|
||||||
|
model, block = _build_bailing_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(block, "_afb_bias"))
|
||||||
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
|
block(hidden)
|
||||||
|
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
|
||||||
|
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||||
|
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
|
||||||
|
|
||||||
|
def test_llama4_adapter_biases_router_selection(self):
|
||||||
|
model, layer = _build_llama4_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||||
|
hidden = torch.randn(2, 4, layer.hidden_dim)
|
||||||
|
layer(hidden)
|
||||||
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
|
||||||
|
|
||||||
|
def test_bias_warmup_respected(self):
|
||||||
|
model, block = _build_bailing_model()
|
||||||
|
cfg = _cfg(moe_afb_warmup_steps=2)
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||||
|
self.assertTrue(callbacks)
|
||||||
|
callback = callbacks[0]
|
||||||
|
dummy = SimpleNamespace()
|
||||||
|
|
||||||
|
def _step():
|
||||||
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
|
block(hidden)
|
||||||
|
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||||
|
|
||||||
|
# Warmup steps should leave bias untouched.
|
||||||
|
_step()
|
||||||
|
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||||
|
|
||||||
|
_step()
|
||||||
|
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||||
|
|
||||||
|
# Third step exceeds warmup -> bias should update.
|
||||||
|
_step()
|
||||||
|
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||||
|
|
||||||
|
def test_mixtral_adapter_respects_native_forward(self):
|
||||||
|
model, layer = _build_mixtral_model()
|
||||||
|
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
|
||||||
|
|
||||||
|
hidden_dim = layer.config.hidden_size
|
||||||
|
hidden = torch.randn(2, 3, hidden_dim)
|
||||||
|
baseline_out, baseline_logits = layer(hidden.clone())
|
||||||
|
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
patched_out, patched_logits = layer(hidden.clone())
|
||||||
|
self.assertTrue(torch.allclose(baseline_out, patched_out))
|
||||||
|
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
|
||||||
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
|
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
model, block = _build_bailing_model()
|
||||||
|
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
self.assertIsNotNone(plugin._shim)
|
||||||
|
self.assertIsNone(plugin._shim.ep_group)
|
||||||
|
|
||||||
|
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||||
|
self.assertTrue(callbacks)
|
||||||
|
callback = callbacks[0]
|
||||||
|
dummy = SimpleNamespace()
|
||||||
|
|
||||||
|
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
tmp_init.close()
|
||||||
|
init_method = f"file://{tmp_init.name}"
|
||||||
|
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
|
||||||
|
try:
|
||||||
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
|
block(hidden)
|
||||||
|
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||||
|
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||||
|
finally:
|
||||||
|
dist.destroy_process_group()
|
||||||
|
os.unlink(tmp_init.name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user