update for transformers v5 for experts parameters and compose with moe kernels
This commit is contained in:
@@ -1,11 +1,19 @@
|
||||
"""Architecture-specific adapters for aux-loss-free MoE routing.
|
||||
|
||||
Each adapter discovers MoE layers for a model family and patches only the
|
||||
router/gate to inject per-expert bias into expert selection while keeping
|
||||
mixture weights from unbiased logits. Expert dispatch is left untouched so
|
||||
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -23,9 +31,10 @@ class LayerHandle:
|
||||
|
||||
|
||||
class BaseMoEAdapter:
|
||||
"""Base adapter that discovers MoE layers and wraps their forward.
|
||||
"""Base adapter that discovers MoE layers and patches their routing.
|
||||
|
||||
Concrete adapters should implement discovery and per-layer attribute extraction.
|
||||
Concrete adapters implement discovery, attribute extraction, and
|
||||
architecture-specific router patching.
|
||||
"""
|
||||
|
||||
family: str = "generic"
|
||||
@@ -33,158 +42,191 @@ class BaseMoEAdapter:
|
||||
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
||||
return False
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
|
||||
def find_moe_layers(
|
||||
self, model: nn.Module
|
||||
) -> Iterable[nn.Module]: # pragma: no cover
|
||||
return []
|
||||
|
||||
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
||||
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))
|
||||
def get_top_k(self, moe_layer: nn.Module) -> int:
|
||||
"""Resolve top_k from the MoE layer, checking common attribute paths."""
|
||||
for attr_path in [
|
||||
("top_k",),
|
||||
("num_experts_per_tok",),
|
||||
("gate", "top_k"),
|
||||
("router", "top_k"),
|
||||
]:
|
||||
obj: object = moe_layer
|
||||
for attr in attr_path:
|
||||
obj = getattr(obj, attr, None)
|
||||
if obj is None:
|
||||
break
|
||||
if isinstance(obj, int):
|
||||
return obj
|
||||
return 2
|
||||
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
||||
return int(getattr(moe_layer, "num_experts"))
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||
"""Resolve num_experts from the MoE layer, checking common attribute paths."""
|
||||
for attr_path in [
|
||||
("num_experts",),
|
||||
("num_local_experts",),
|
||||
("gate", "num_experts"),
|
||||
("router", "num_experts"),
|
||||
("experts", "num_experts"),
|
||||
]:
|
||||
obj: object = moe_layer
|
||||
for attr in attr_path:
|
||||
obj = getattr(obj, attr, None)
|
||||
if obj is None:
|
||||
break
|
||||
if isinstance(obj, int):
|
||||
return obj
|
||||
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||
|
||||
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
||||
# Best-effort: zero router aux loss coef if present
|
||||
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
||||
try:
|
||||
setattr(model_or_layer, "router_aux_loss_coef", 0.0)
|
||||
model_or_layer.router_aux_loss_coef = 0.0
|
||||
except Exception: # pragma: no cover - non-critical
|
||||
pass
|
||||
|
||||
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
def _register_aux_buffers(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
||||
if not hasattr(moe_layer, "_afb_bias"):
|
||||
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
|
||||
moe_layer.register_buffer(
|
||||
"_afb_bias", torch.zeros(handle.num_experts, device=device)
|
||||
)
|
||||
if not hasattr(moe_layer, "_afb_counts"):
|
||||
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
|
||||
moe_layer.register_buffer(
|
||||
"_afb_counts", torch.zeros(handle.num_experts, device=device)
|
||||
)
|
||||
if not hasattr(moe_layer, "_afb_ema"):
|
||||
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
|
||||
moe_layer.register_buffer(
|
||||
"_afb_ema", torch.zeros(handle.num_experts, device=device)
|
||||
)
|
||||
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
"""Attach per-layer buffers and mark as aux-free enabled."""
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
"""Attach per-layer buffers. Subclasses override to also patch routing."""
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_forward_with_aux_free(moe_layer)
|
||||
|
||||
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
|
||||
"""Replace the layer's forward with an aux-free gating version.
|
||||
|
||||
Assumes the layer exposes attributes:
|
||||
- gate: linear router projecting hidden to num_experts
|
||||
- num_experts: int
|
||||
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
|
||||
"""
|
||||
if getattr(moe_layer, "_afb_patched", False):
|
||||
return
|
||||
|
||||
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
|
||||
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
|
||||
return
|
||||
|
||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
||||
# hidden_states: (B, T, H)
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
hs = hidden_states.view(-1, hdim)
|
||||
logits = self.gate(hs)
|
||||
# selection uses biased logits; weights from unbiased logits
|
||||
bias = getattr(self, "_afb_bias")
|
||||
top_k = int(getattr(self, "_afb_top_k", 2))
|
||||
biased = logits + bias # broadcast over tokens
|
||||
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
|
||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||
weights = torch.softmax(chosen_logits.float(), dim=-1)
|
||||
weights = weights.to(hs.dtype)
|
||||
|
||||
# accumulate counts for bias update callback
|
||||
flat_idx = topk_idx.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
||||
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
|
||||
|
||||
# dispatch tokens to experts
|
||||
hs_rep = hs.repeat_interleave(top_k, dim=0)
|
||||
y = torch.empty_like(hs_rep)
|
||||
for eid in range(int(self.num_experts)):
|
||||
mask = flat_idx == eid
|
||||
if mask.any():
|
||||
y[mask] = self.experts[eid](hs_rep[mask])
|
||||
|
||||
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
|
||||
out = y.view(bsz, seqlen, hdim)
|
||||
return (out, logits)
|
||||
|
||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
||||
setattr(moe_layer, "_afb_patched", True)
|
||||
def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
|
||||
"""Return True when a kernel backend (SonicMoE / ScatterMoE) has
|
||||
already replaced the block forward, meaning the routing is handled
|
||||
inside the kernel forward and we should NOT patch the router."""
|
||||
cls = type(moe_layer)
|
||||
# SonicMoE stores the original forward when it patches a class.
|
||||
if hasattr(cls, "_original_forward"):
|
||||
return True
|
||||
# ScatterMoE replaces via kernels library; check for the marker.
|
||||
if hasattr(cls, "_kernel_forward"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class MixtralAdapter(BaseMoEAdapter):
|
||||
"""Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
|
||||
routing so that biased logits drive expert *selection* while unbiased
|
||||
softmax scores drive mixture *weights*.
|
||||
|
||||
Works with transformers v5 where experts are fused 3D tensors and
|
||||
the router is ``MixtralTopKRouter`` (returns a 3-tuple).
|
||||
"""
|
||||
|
||||
family = "mixtral"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_mixtral_forward(moe_layer, shim)
|
||||
return (
|
||||
getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||
)
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
||||
yield m
|
||||
|
||||
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
|
||||
if getattr(moe_layer, "_afb_patched", False):
|
||||
return
|
||||
|
||||
shim_ref = shim
|
||||
|
||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
if self.training and getattr(self, "jitter_noise", 0) > 0:
|
||||
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
|
||||
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
|
||||
)
|
||||
flat_states = hidden_states.view(-1, hidden_dim)
|
||||
router_logits = self.gate(flat_states)
|
||||
|
||||
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
|
||||
top_k = int(getattr(self, "_afb_top_k", self.top_k))
|
||||
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
|
||||
routing_weights = routing_weights.to(flat_states.dtype)
|
||||
|
||||
flat_idx = selected_experts.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
||||
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=flat_states.dtype,
|
||||
device=flat_states.device,
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
if not self.uses_kernel_routing(moe_layer):
|
||||
self._patch_router(moe_layer)
|
||||
else:
|
||||
LOG.info(
|
||||
"AuxFreeMoE: kernel backend detected on %s; "
|
||||
"skipping router patch (kernel routing handles bias)",
|
||||
type(moe_layer).__name__,
|
||||
)
|
||||
|
||||
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
for expert_idx_tensor in expert_hit:
|
||||
expert_idx = int(expert_idx_tensor.squeeze().item())
|
||||
expert_layer = self.experts[expert_idx]
|
||||
mask = expert_mask[expert_idx].squeeze(0)
|
||||
idx, top_x = torch.where(mask)
|
||||
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
|
||||
def _patch_router(self, moe_layer: nn.Module) -> None:
|
||||
"""Patch the TopKRouter to inject aux-free bias into expert selection."""
|
||||
gate = getattr(moe_layer, "gate", None)
|
||||
if gate is None:
|
||||
LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
|
||||
return
|
||||
if getattr(gate, "_afb_patched", False):
|
||||
return
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
# Capture reference to the MoE block for bias / counts access.
|
||||
block_ref = moe_layer
|
||||
|
||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
||||
setattr(moe_layer, "_afb_patched", True)
|
||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||
router_logits = F.linear(hidden_states, self.weight)
|
||||
router_probs = F.softmax(router_logits.float(), dim=-1)
|
||||
|
||||
# Biased selection, unbiased weights
|
||||
bias = block_ref._afb_bias
|
||||
biased = router_probs + bias
|
||||
_, router_indices = torch.topk(biased, self.top_k, dim=-1)
|
||||
router_scores = torch.gather(router_probs, 1, router_indices)
|
||||
|
||||
# Renormalize (Mixtral always normalizes; Qwen checks config)
|
||||
if getattr(self, "norm_topk_prob", True):
|
||||
router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Accumulate counts for the bias-update callback
|
||||
flat_idx = router_indices.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=self.num_experts)
|
||||
block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
|
||||
|
||||
return router_probs, router_scores, router_indices
|
||||
|
||||
gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||
gate._afb_patched = True
|
||||
moe_layer._afb_patched = True
|
||||
|
||||
|
||||
class Qwen3Adapter(MixtralAdapter):
|
||||
family = "qwen3_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||
"qwen3_moe",
|
||||
"qwen2_moe",
|
||||
)
|
||||
|
||||
|
||||
class Qwen35MoeAdapter(MixtralAdapter):
|
||||
"""Adapter for Qwen 3.5 MoE models.
|
||||
|
||||
Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
|
||||
is handled by the block forward (untouched by router-level patching).
|
||||
"""
|
||||
|
||||
family = "qwen3_5_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||
"qwen3_5_moe",
|
||||
"qwen3_5_moe_text",
|
||||
)
|
||||
|
||||
|
||||
class BailingAdapter(BaseMoEAdapter):
|
||||
@@ -207,11 +249,15 @@ class BailingAdapter(BaseMoEAdapter):
|
||||
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||
if hasattr(moe_layer, "num_experts"):
|
||||
return int(getattr(moe_layer, "num_experts"))
|
||||
return int(moe_layer.num_experts)
|
||||
cfg = getattr(moe_layer, "config", None)
|
||||
return int(getattr(cfg, "num_experts"))
|
||||
if cfg is None:
|
||||
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||
return int(cfg.num_experts)
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_bailing_gate(moe_layer)
|
||||
|
||||
@@ -227,7 +273,7 @@ class BailingAdapter(BaseMoEAdapter):
|
||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
logits = F.linear(flat.float(), self.weight.float())
|
||||
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||
bias = getattr(moe_layer, "_afb_bias")
|
||||
bias = moe_layer._afb_bias
|
||||
biased_scores = scores_unbiased + bias
|
||||
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
||||
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||
@@ -238,12 +284,12 @@ class BailingAdapter(BaseMoEAdapter):
|
||||
|
||||
flat_topk = topk_idx.reshape(-1)
|
||||
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
|
||||
return topk_idx, weights.to(hidden_states.dtype), logits
|
||||
|
||||
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||
setattr(gate, "_afb_patched", True)
|
||||
gate._afb_patched = True
|
||||
|
||||
|
||||
class Llama4Adapter(BaseMoEAdapter):
|
||||
@@ -257,7 +303,9 @@ class Llama4Adapter(BaseMoEAdapter):
|
||||
if m.__class__.__name__ == "Llama4TextMoe":
|
||||
yield m
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_llama4_router(moe_layer)
|
||||
|
||||
@@ -270,9 +318,13 @@ class Llama4Adapter(BaseMoEAdapter):
|
||||
return
|
||||
|
||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
|
||||
flat = (
|
||||
hidden_states
|
||||
if hidden_states.dim() == 2
|
||||
else hidden_states.view(-1, hidden_states.shape[-1])
|
||||
)
|
||||
router_logits = F.linear(flat, self.weight, self.bias)
|
||||
bias = getattr(moe_layer, "_afb_bias")
|
||||
bias = moe_layer._afb_bias
|
||||
biased_logits = router_logits + bias
|
||||
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
||||
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
||||
@@ -281,15 +333,17 @@ class Llama4Adapter(BaseMoEAdapter):
|
||||
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||
|
||||
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
|
||||
return router_scores, router_logits
|
||||
|
||||
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
||||
setattr(router, "_afb_patched", True)
|
||||
router._afb_patched = True
|
||||
|
||||
|
||||
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
|
||||
def discover_and_prepare_layers(
|
||||
model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
|
||||
) -> list[LayerHandle]:
|
||||
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||
|
||||
Returns a list of layer handles for later routing patching and updates.
|
||||
@@ -313,7 +367,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
|
||||
try:
|
||||
top_k = adapter.get_top_k(layer)
|
||||
nE = adapter.get_num_experts(layer)
|
||||
except Exception:
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
continue
|
||||
|
||||
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
||||
@@ -321,5 +375,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
|
||||
handles.append(handle)
|
||||
idx += 1
|
||||
|
||||
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
|
||||
LOG.info(
|
||||
f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing"
|
||||
)
|
||||
return handles
|
||||
|
||||
@@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -21,6 +21,7 @@ from .adapters import (
|
||||
Llama4Adapter,
|
||||
MixtralAdapter,
|
||||
Qwen3Adapter,
|
||||
Qwen35MoeAdapter,
|
||||
discover_and_prepare_layers,
|
||||
)
|
||||
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||
@@ -47,9 +48,11 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||
# Iterate prepared MoE layers and apply the bias update rule.
|
||||
self.shim.begin_step()
|
||||
for layer in self.layer_modules:
|
||||
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
|
||||
if not hasattr(layer, "_afb_counts") or not hasattr(
|
||||
layer, "_afb_layer_idx"
|
||||
):
|
||||
continue
|
||||
counts = getattr(layer, "_afb_counts")
|
||||
counts = layer._afb_counts
|
||||
if counts is None:
|
||||
continue
|
||||
counts = self.shim.all_reduce_counts(counts)
|
||||
@@ -57,7 +60,7 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||
if layer_idx is None:
|
||||
counts.zero_()
|
||||
continue
|
||||
bias = getattr(layer, "_afb_bias")
|
||||
bias = layer._afb_bias
|
||||
counts_for_update = counts.to(bias.device)
|
||||
tokens_seen = int(counts_for_update.sum().item())
|
||||
# local layer-state EMA and bias update
|
||||
@@ -143,12 +146,16 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
return
|
||||
|
||||
# Be conservative — skip known native aux-free families
|
||||
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||
native_auxfree = getattr(
|
||||
getattr(model, "config", object()), "model_type", ""
|
||||
) in (
|
||||
"deepseek_v3",
|
||||
"glm4_moe",
|
||||
)
|
||||
if native_auxfree:
|
||||
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
|
||||
LOG.info(
|
||||
"AuxFreeMoE: model reports native aux-free routing; skipping patching"
|
||||
)
|
||||
return
|
||||
|
||||
# Build aux-free state and shim
|
||||
@@ -171,6 +178,7 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
adapters: list[BaseMoEAdapter] = [
|
||||
MixtralAdapter(),
|
||||
Qwen3Adapter(),
|
||||
Qwen35MoeAdapter(),
|
||||
BailingAdapter(),
|
||||
Llama4Adapter(),
|
||||
]
|
||||
@@ -178,7 +186,7 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||
n_layers = 0
|
||||
n_experts = None
|
||||
for m in model.modules():
|
||||
for _m in model.modules():
|
||||
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
||||
device = next(model.parameters(), torch.tensor(0)).device
|
||||
if n_layers <= 0:
|
||||
@@ -186,7 +194,9 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
if n_experts is None:
|
||||
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||
n_experts = 2
|
||||
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
|
||||
state = AuxFreeState(
|
||||
num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
|
||||
)
|
||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||
ep_group = None
|
||||
if sync_group == "ep":
|
||||
@@ -207,11 +217,15 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
|
||||
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
|
||||
)
|
||||
return None
|
||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||
if not ep_size or ep_size <= 1:
|
||||
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
|
||||
)
|
||||
return None
|
||||
world = dist.get_world_size()
|
||||
if world % ep_size != 0:
|
||||
|
||||
@@ -240,7 +240,16 @@ def _softmax_topk_route(
|
||||
|
||||
top_k = base_gate.top_k
|
||||
num_experts = base_gate.num_experts
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||
|
||||
# Aux-free bias: biased selection, unbiased weights
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = routing_weights + afb_bias
|
||||
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
|
||||
routing_weights = routing_weights.gather(1, selected_experts)
|
||||
_accumulate_afb_counts(moe_block, selected_experts)
|
||||
else:
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||
|
||||
if getattr(base_gate, "norm_topk_prob", True):
|
||||
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
||||
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
|
||||
else:
|
||||
scores_for_choice = router_probs
|
||||
|
||||
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = scores_for_choice + afb_bias
|
||||
|
||||
# Group-based selection: pick top groups, mask the rest
|
||||
n_group = getattr(moe_block, "n_group", 1)
|
||||
if n_group > 1:
|
||||
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
|
||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||
topk_weights = router_probs.gather(1, topk_indices)
|
||||
|
||||
# Accumulate counts for aux-free bias update
|
||||
if afb_bias is not None:
|
||||
_accumulate_afb_counts(moe_block, topk_indices)
|
||||
|
||||
# Optional renormalization + scaling
|
||||
if getattr(moe_block, "norm_topk_prob", True):
|
||||
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
||||
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
|
||||
)
|
||||
|
||||
|
||||
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||
"""Accumulate per-expert token counts for aux-free bias updates."""
|
||||
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||
if afb_counts is None:
|
||||
return
|
||||
flat_idx = topk_indices.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
|
||||
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shared expert helpers
|
||||
# =============================================================================
|
||||
|
||||
@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
|
||||
|
||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||
|
||||
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
|
||||
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
|
||||
as buffers. The routing functions transparently inject the bias into expert
|
||||
*selection* (biased topk) while keeping mixture *weights* from unbiased
|
||||
scores, then accumulate per-expert token counts for the post-step bias update.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -101,17 +107,25 @@ def softmax_topk_routing(
|
||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||
|
||||
# Aux-free bias: biased selection, unbiased weights
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
scores_for_choice = router_probs
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = router_probs + afb_bias
|
||||
|
||||
# Select top-k experts per token
|
||||
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
|
||||
top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
|
||||
|
||||
# When aux-free bias is active, gather unbiased weights and accumulate counts
|
||||
if afb_bias is not None:
|
||||
top_values = router_probs.gather(1, top_indices)
|
||||
_accumulate_afb_counts(moe_block, top_indices)
|
||||
|
||||
# Renormalize if configured (default True for models without the attribute,
|
||||
# e.g. Mixtral/MiniMax which always normalize)
|
||||
if getattr(gate, "norm_topk_prob", True):
|
||||
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
||||
|
||||
# no-op: matches transformers which casts to softmax output dtype (float32).
|
||||
# top_values = top_values.to(router_probs.dtype)
|
||||
|
||||
# Flatten for moe_general_routing_inputs.
|
||||
# Token indices are naturally sorted ascending from the [T, K] layout:
|
||||
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
||||
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
|
||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||
|
||||
# Aux-free bias: inject before group selection / topk
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
scores_for_choice = router_probs
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = router_probs + afb_bias
|
||||
|
||||
# Group selection: pick top groups, mask the rest
|
||||
if n_group > 1:
|
||||
@@ -164,6 +182,10 @@ def softmax_group_topk_routing(
|
||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||
topk_weights = router_probs.gather(1, topk_indices)
|
||||
|
||||
# Accumulate counts for aux-free bias update
|
||||
if afb_bias is not None:
|
||||
_accumulate_afb_counts(moe_block, topk_indices)
|
||||
|
||||
# Renormalization + scaling
|
||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||
if norm_topk_prob:
|
||||
@@ -233,6 +255,11 @@ def sigmoid_topk_routing(
|
||||
)
|
||||
scores_for_choice = router_probs + e_score_correction_bias
|
||||
|
||||
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = scores_for_choice + afb_bias
|
||||
|
||||
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
||||
if n_group > 1:
|
||||
group_scores = (
|
||||
@@ -256,6 +283,10 @@ def sigmoid_topk_routing(
|
||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||
topk_weights = router_probs.gather(1, topk_indices)
|
||||
|
||||
# Accumulate counts for aux-free bias update
|
||||
if afb_bias is not None:
|
||||
_accumulate_afb_counts(moe_block, topk_indices)
|
||||
|
||||
# Optional renormalization + scaling
|
||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||
if norm_topk_prob:
|
||||
@@ -276,3 +307,19 @@ def sigmoid_topk_routing(
|
||||
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||
|
||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||
|
||||
|
||||
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||
"""Accumulate per-expert token counts for the aux-free bias update.
|
||||
|
||||
Called when ``moe_block._afb_bias`` is present (registered by the
|
||||
``aux_free_router`` plugin). The counts are later consumed by the
|
||||
``MoeAuxFreeBiasUpdateCallback`` at each training step.
|
||||
"""
|
||||
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||
if afb_counts is None:
|
||||
return
|
||||
num_experts = afb_counts.numel()
|
||||
flat_idx = topk_indices.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=num_experts)
|
||||
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||
|
||||
@@ -2,14 +2,13 @@ import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||
@@ -96,16 +95,21 @@ def _build_bailing_model():
|
||||
|
||||
|
||||
def _build_llama4_model():
|
||||
from transformers import Llama4TextConfig
|
||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||
|
||||
config = Llama4TextConfig(
|
||||
# Build config without __post_init__ validation (works around a
|
||||
# huggingface_hub strict-dataclass type mismatch for layer_types).
|
||||
config = object.__new__(__import__("transformers").Llama4TextConfig)
|
||||
config.__dict__.update(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_local_experts=4,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_experts_per_tok=2,
|
||||
num_hidden_layers=2,
|
||||
hidden_act="silu",
|
||||
layer_types=None,
|
||||
)
|
||||
layer = Llama4TextMoe(config)
|
||||
|
||||
@@ -148,6 +152,38 @@ def _build_mixtral_model():
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _build_qwen35_moe_model():
|
||||
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
|
||||
Qwen3_5MoeTextConfig,
|
||||
)
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
Qwen3_5MoeSparseMoeBlock,
|
||||
)
|
||||
|
||||
config = Qwen3_5MoeTextConfig(
|
||||
hidden_size=16,
|
||||
moe_intermediate_size=32,
|
||||
shared_expert_intermediate_size=32,
|
||||
num_experts=4,
|
||||
num_experts_per_tok=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=2,
|
||||
)
|
||||
layer = Qwen3_5MoeSparseMoeBlock(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="qwen3_5_moe")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
if args is None:
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
@@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
|
||||
self.assertFalse(
|
||||
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
|
||||
)
|
||||
|
||||
def test_llama4_adapter_biases_router_selection(self):
|
||||
model, layer = _build_llama4_model()
|
||||
@@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
|
||||
self.assertFalse(
|
||||
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||
)
|
||||
|
||||
def test_bias_warmup_respected(self):
|
||||
model, block = _build_bailing_model()
|
||||
@@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
# Warmup steps should leave bias untouched.
|
||||
_step()
|
||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||
self.assertTrue(
|
||||
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||
)
|
||||
|
||||
_step()
|
||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||
self.assertTrue(
|
||||
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||
)
|
||||
|
||||
# Third step exceeds warmup -> bias should update.
|
||||
_step()
|
||||
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||
|
||||
def test_mixtral_adapter_respects_native_forward(self):
|
||||
def test_mixtral_adapter_patches_router_not_forward(self):
|
||||
"""Verify that aux-free patches the router (gate) only, and the
|
||||
v5 block forward signature (single tensor return) is preserved."""
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
|
||||
|
||||
hidden_dim = layer.config.hidden_size
|
||||
hidden = torch.randn(2, 3, hidden_dim)
|
||||
baseline_out, baseline_logits = layer(hidden.clone())
|
||||
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
patched_out, patched_logits = layer(hidden.clone())
|
||||
self.assertTrue(torch.allclose(baseline_out, patched_out))
|
||||
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
|
||||
# Gate should be patched, not the block forward
|
||||
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||
|
||||
# v5 block forward returns a single tensor (not a tuple with logits)
|
||||
hidden = torch.randn(2, 3, layer.config.hidden_size)
|
||||
out = layer(hidden)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.shape, hidden.shape)
|
||||
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
def test_mixtral_adapter_bias_affects_selection(self):
|
||||
"""When bias is large for one expert, it should be selected more often."""
|
||||
model, layer = _build_mixtral_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Set a large bias for expert 0 to force its selection
|
||||
layer._afb_bias.zero_()
|
||||
layer._afb_bias[0] = 10.0
|
||||
|
||||
hidden = torch.randn(2, 8, layer.config.hidden_size)
|
||||
num_tokens = 2 * 8 # batch * seq
|
||||
layer(hidden)
|
||||
|
||||
# With top_k=2, expert 0 should appear in every token's selection
|
||||
# (once per token = num_tokens counts, not num_tokens * top_k)
|
||||
counts = layer._afb_counts.clone()
|
||||
self.assertEqual(
|
||||
int(counts[0].item()),
|
||||
num_tokens,
|
||||
msg="Expert 0 should be selected for every token when heavily biased",
|
||||
)
|
||||
|
||||
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
|
||||
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
|
||||
output includes shared expert contribution."""
|
||||
model, layer = _build_qwen35_moe_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Gate should be patched
|
||||
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||
# Shared expert should be unmodified
|
||||
self.assertTrue(hasattr(layer, "shared_expert"))
|
||||
self.assertTrue(hasattr(layer, "shared_expert_gate"))
|
||||
|
||||
# Forward should return a single tensor (shared + routed)
|
||||
hidden_size = layer.gate.hidden_dim
|
||||
hidden = torch.randn(2, 3, hidden_size)
|
||||
out = layer(hidden)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.shape, hidden.shape)
|
||||
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
|
||||
def test_qwen35_moe_adapter_bias_updates(self):
|
||||
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
|
||||
model, layer = _build_qwen35_moe_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
hidden_size = layer.gate.hidden_dim
|
||||
hidden = torch.randn(2, 4, hidden_size)
|
||||
layer(hidden)
|
||||
|
||||
# Bias should start at zero
|
||||
self.assertTrue(
|
||||
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
|
||||
)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
# After callback: counts reset, EMA updated, bias updated
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(
|
||||
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||
)
|
||||
|
||||
def test_qwen35_moe_adapter_model_type_matching(self):
|
||||
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
|
||||
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
|
||||
|
||||
adapter = Qwen35MoeAdapter()
|
||||
|
||||
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
|
||||
model_text = SimpleNamespace(
|
||||
config=SimpleNamespace(model_type="qwen3_5_moe_text")
|
||||
)
|
||||
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
|
||||
|
||||
self.assertTrue(adapter.matches(model_moe))
|
||||
self.assertTrue(adapter.matches(model_text))
|
||||
self.assertFalse(adapter.matches(model_other))
|
||||
|
||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
@@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||
tmp_init.close()
|
||||
init_method = f"file://{tmp_init.name}"
|
||||
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
|
||||
dist.init_process_group(
|
||||
backend="gloo", init_method=init_method, world_size=1, rank=0
|
||||
)
|
||||
try:
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
@@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
def test_telemetry_logging(self):
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
@@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||
|
||||
def test_get_num_experts_v5_attribute_paths(self):
|
||||
"""Verify get_num_experts works with v5 attribute layout where
|
||||
num_experts is on gate/experts sub-modules, not the block."""
|
||||
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||
|
||||
adapter = MixtralAdapter()
|
||||
|
||||
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
|
||||
block = SimpleNamespace(
|
||||
gate=SimpleNamespace(num_experts=8),
|
||||
experts=SimpleNamespace(num_experts=8),
|
||||
)
|
||||
self.assertEqual(adapter.get_num_experts(block), 8)
|
||||
|
||||
# Also works when num_experts is directly on block
|
||||
block2 = SimpleNamespace(num_experts=4)
|
||||
self.assertEqual(adapter.get_num_experts(block2), 4)
|
||||
|
||||
|
||||
class TestAuxFreeKernelComposition(unittest.TestCase):
|
||||
"""Tests that aux-free bias composes correctly with kernel routing."""
|
||||
|
||||
def test_sonicmoe_softmax_routing_with_afb_bias(self):
|
||||
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
T = 6
|
||||
|
||||
# Build a mock MoE block with gate attributes
|
||||
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
gate.top_k = top_k
|
||||
gate.num_experts = num_experts
|
||||
gate.norm_topk_prob = True
|
||||
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(T, hidden_dim)
|
||||
|
||||
# Baseline: no bias
|
||||
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
self.assertEqual(scores_base.shape[0], T * top_k)
|
||||
|
||||
# Now register aux-free buffers and set heavy bias on expert 0
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_bias[0] = 100.0
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
|
||||
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
|
||||
# Expert 0 should be selected for every token
|
||||
self.assertTrue(
|
||||
(exp_biased == 0).any(),
|
||||
"Expert 0 should appear in selections when heavily biased",
|
||||
)
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||
# Total counts should equal T * top_k
|
||||
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||
|
||||
def test_sonicmoe_routing_without_bias_unchanged(self):
|
||||
"""Without _afb_bias, routing should produce identical results."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
|
||||
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
gate.top_k = top_k
|
||||
gate.num_experts = num_experts
|
||||
gate.norm_topk_prob = True
|
||||
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(4, hidden_dim)
|
||||
|
||||
# Without _afb_bias attribute
|
||||
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
# With _afb_bias = zeros (should be equivalent)
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
torch.testing.assert_close(scores1, scores2)
|
||||
torch.testing.assert_close(exp1, exp2)
|
||||
|
||||
@unittest.skipUnless(
|
||||
importlib_util.find_spec("triton") is not None,
|
||||
"triton not installed (required by scattermoe)",
|
||||
)
|
||||
def test_scattermoe_softmax_routing_with_afb_bias(self):
|
||||
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
T = 6
|
||||
|
||||
gate_weight = torch.randn(num_experts, hidden_dim)
|
||||
base_gate = SimpleNamespace(
|
||||
top_k=top_k,
|
||||
num_experts=num_experts,
|
||||
norm_topk_prob=True,
|
||||
weight=gate_weight,
|
||||
)
|
||||
|
||||
moe_block = SimpleNamespace()
|
||||
hidden = torch.randn(T, hidden_dim)
|
||||
|
||||
# Baseline without bias
|
||||
w_base, e_base, _, _ = _softmax_topk_route(
|
||||
moe_block, base_gate, hidden, gate_weight, None
|
||||
)
|
||||
|
||||
# With heavy bias on expert 0
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_bias[0] = 100.0
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
|
||||
w_biased, e_biased, _, _ = _softmax_topk_route(
|
||||
moe_block, base_gate, hidden, gate_weight, None
|
||||
)
|
||||
|
||||
# Expert 0 should appear in all selections
|
||||
self.assertTrue((e_biased == 0).any())
|
||||
# Counts accumulated
|
||||
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||
|
||||
def test_kernel_routing_skips_router_patch(self):
|
||||
"""When a kernel backend has patched the block class, the adapter
|
||||
should skip patching the router (buffers are still registered)."""
|
||||
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||
|
||||
adapter = MixtralAdapter()
|
||||
|
||||
# Create a mock layer whose class has _original_forward (SonicMoE marker)
|
||||
class PatchedBlock(nn.Module):
|
||||
_original_forward = True # SonicMoE marker
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(16, 4, bias=False)
|
||||
self.gate.top_k = 2
|
||||
self.gate.num_experts = 4
|
||||
self.gate.hidden_dim = 16
|
||||
self.experts = nn.Linear(16, 16) # placeholder
|
||||
|
||||
layer = PatchedBlock()
|
||||
self.assertTrue(adapter.uses_kernel_routing(layer))
|
||||
|
||||
# Gate should NOT be patched (kernel handles routing)
|
||||
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||
|
||||
def test_adapter_buffers_registered_even_with_kernel(self):
|
||||
"""Even when kernel routing is active, aux-free buffers must be
|
||||
registered on the MoE block so the kernel routing can find them."""
|
||||
from axolotl.integrations.aux_free_router.adapters import (
|
||||
LayerHandle,
|
||||
MixtralAdapter,
|
||||
)
|
||||
from axolotl.integrations.aux_free_router.core import (
|
||||
AuxFreeConfig,
|
||||
AuxFreeShim,
|
||||
AuxFreeState,
|
||||
)
|
||||
|
||||
class PatchedBlock(nn.Module):
|
||||
_original_forward = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(16, 4, bias=False)
|
||||
self.gate.top_k = 2
|
||||
self.gate.num_experts = 4
|
||||
self.gate.hidden_dim = 16
|
||||
self.experts = nn.Linear(16, 16)
|
||||
|
||||
layer = PatchedBlock()
|
||||
adapter = MixtralAdapter()
|
||||
cfg = AuxFreeConfig()
|
||||
state = AuxFreeState(
|
||||
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
|
||||
)
|
||||
shim = AuxFreeShim(state=state)
|
||||
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
|
||||
|
||||
adapter.prepare(layer, handle, shim)
|
||||
|
||||
# Buffers should be registered for kernel routing to use
|
||||
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||
self.assertTrue(hasattr(layer, "_afb_counts"))
|
||||
self.assertTrue(hasattr(layer, "_afb_ema"))
|
||||
# But gate should NOT be patched
|
||||
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user