update for transformers v5 for experts parameters and compose with moe kernels
This commit is contained in:
@@ -1,11 +1,19 @@
|
|||||||
|
"""Architecture-specific adapters for aux-loss-free MoE routing.
|
||||||
|
|
||||||
|
Each adapter discovers MoE layers for a model family and patches only the
|
||||||
|
router/gate to inject per-expert bias into expert selection while keeping
|
||||||
|
mixture weights from unbiased logits. Expert dispatch is left untouched so
|
||||||
|
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, Optional
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -23,9 +31,10 @@ class LayerHandle:
|
|||||||
|
|
||||||
|
|
||||||
class BaseMoEAdapter:
|
class BaseMoEAdapter:
|
||||||
"""Base adapter that discovers MoE layers and wraps their forward.
|
"""Base adapter that discovers MoE layers and patches their routing.
|
||||||
|
|
||||||
Concrete adapters should implement discovery and per-layer attribute extraction.
|
Concrete adapters implement discovery, attribute extraction, and
|
||||||
|
architecture-specific router patching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
family: str = "generic"
|
family: str = "generic"
|
||||||
@@ -33,158 +42,191 @@ class BaseMoEAdapter:
|
|||||||
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
|
def find_moe_layers(
|
||||||
|
self, model: nn.Module
|
||||||
|
) -> Iterable[nn.Module]: # pragma: no cover
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
def get_top_k(self, moe_layer: nn.Module) -> int:
|
||||||
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))
|
"""Resolve top_k from the MoE layer, checking common attribute paths."""
|
||||||
|
for attr_path in [
|
||||||
|
("top_k",),
|
||||||
|
("num_experts_per_tok",),
|
||||||
|
("gate", "top_k"),
|
||||||
|
("router", "top_k"),
|
||||||
|
]:
|
||||||
|
obj: object = moe_layer
|
||||||
|
for attr in attr_path:
|
||||||
|
obj = getattr(obj, attr, None)
|
||||||
|
if obj is None:
|
||||||
|
break
|
||||||
|
if isinstance(obj, int):
|
||||||
|
return obj
|
||||||
|
return 2
|
||||||
|
|
||||||
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||||
return int(getattr(moe_layer, "num_experts"))
|
"""Resolve num_experts from the MoE layer, checking common attribute paths."""
|
||||||
|
for attr_path in [
|
||||||
|
("num_experts",),
|
||||||
|
("num_local_experts",),
|
||||||
|
("gate", "num_experts"),
|
||||||
|
("router", "num_experts"),
|
||||||
|
("experts", "num_experts"),
|
||||||
|
]:
|
||||||
|
obj: object = moe_layer
|
||||||
|
for attr in attr_path:
|
||||||
|
obj = getattr(obj, attr, None)
|
||||||
|
if obj is None:
|
||||||
|
break
|
||||||
|
if isinstance(obj, int):
|
||||||
|
return obj
|
||||||
|
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||||
|
|
||||||
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
||||||
# Best-effort: zero router aux loss coef if present
|
# Best-effort: zero router aux loss coef if present
|
||||||
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
||||||
try:
|
try:
|
||||||
setattr(model_or_layer, "router_aux_loss_coef", 0.0)
|
model_or_layer.router_aux_loss_coef = 0.0
|
||||||
except Exception: # pragma: no cover - non-critical
|
except Exception: # pragma: no cover - non-critical
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
def _register_aux_buffers(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
||||||
if not hasattr(moe_layer, "_afb_bias"):
|
if not hasattr(moe_layer, "_afb_bias"):
|
||||||
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
|
moe_layer.register_buffer(
|
||||||
|
"_afb_bias", torch.zeros(handle.num_experts, device=device)
|
||||||
|
)
|
||||||
if not hasattr(moe_layer, "_afb_counts"):
|
if not hasattr(moe_layer, "_afb_counts"):
|
||||||
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
|
moe_layer.register_buffer(
|
||||||
|
"_afb_counts", torch.zeros(handle.num_experts, device=device)
|
||||||
|
)
|
||||||
if not hasattr(moe_layer, "_afb_ema"):
|
if not hasattr(moe_layer, "_afb_ema"):
|
||||||
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
|
moe_layer.register_buffer(
|
||||||
|
"_afb_ema", torch.zeros(handle.num_experts, device=device)
|
||||||
|
)
|
||||||
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||||
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||||
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
def prepare(
|
||||||
"""Attach per-layer buffers and mark as aux-free enabled."""
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
|
"""Attach per-layer buffers. Subclasses override to also patch routing."""
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
self._patch_forward_with_aux_free(moe_layer)
|
|
||||||
|
|
||||||
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
|
def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
|
||||||
"""Replace the layer's forward with an aux-free gating version.
|
"""Return True when a kernel backend (SonicMoE / ScatterMoE) has
|
||||||
|
already replaced the block forward, meaning the routing is handled
|
||||||
Assumes the layer exposes attributes:
|
inside the kernel forward and we should NOT patch the router."""
|
||||||
- gate: linear router projecting hidden to num_experts
|
cls = type(moe_layer)
|
||||||
- num_experts: int
|
# SonicMoE stores the original forward when it patches a class.
|
||||||
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
|
if hasattr(cls, "_original_forward"):
|
||||||
"""
|
return True
|
||||||
if getattr(moe_layer, "_afb_patched", False):
|
# ScatterMoE replaces via kernels library; check for the marker.
|
||||||
return
|
if hasattr(cls, "_kernel_forward"):
|
||||||
|
return True
|
||||||
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
|
return False
|
||||||
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
|
|
||||||
return
|
|
||||||
|
|
||||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
|
||||||
# hidden_states: (B, T, H)
|
|
||||||
bsz, seqlen, hdim = hidden_states.shape
|
|
||||||
hs = hidden_states.view(-1, hdim)
|
|
||||||
logits = self.gate(hs)
|
|
||||||
# selection uses biased logits; weights from unbiased logits
|
|
||||||
bias = getattr(self, "_afb_bias")
|
|
||||||
top_k = int(getattr(self, "_afb_top_k", 2))
|
|
||||||
biased = logits + bias # broadcast over tokens
|
|
||||||
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
|
|
||||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
|
||||||
weights = torch.softmax(chosen_logits.float(), dim=-1)
|
|
||||||
weights = weights.to(hs.dtype)
|
|
||||||
|
|
||||||
# accumulate counts for bias update callback
|
|
||||||
flat_idx = topk_idx.reshape(-1)
|
|
||||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
|
||||||
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
|
|
||||||
|
|
||||||
# dispatch tokens to experts
|
|
||||||
hs_rep = hs.repeat_interleave(top_k, dim=0)
|
|
||||||
y = torch.empty_like(hs_rep)
|
|
||||||
for eid in range(int(self.num_experts)):
|
|
||||||
mask = flat_idx == eid
|
|
||||||
if mask.any():
|
|
||||||
y[mask] = self.experts[eid](hs_rep[mask])
|
|
||||||
|
|
||||||
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
|
|
||||||
out = y.view(bsz, seqlen, hdim)
|
|
||||||
return (out, logits)
|
|
||||||
|
|
||||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
|
||||||
setattr(moe_layer, "_afb_patched", True)
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralAdapter(BaseMoEAdapter):
|
class MixtralAdapter(BaseMoEAdapter):
|
||||||
|
"""Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
|
||||||
|
routing so that biased logits drive expert *selection* while unbiased
|
||||||
|
softmax scores drive mixture *weights*.
|
||||||
|
|
||||||
|
Works with transformers v5 where experts are fused 3D tensors and
|
||||||
|
the router is ``MixtralTopKRouter`` (returns a 3-tuple).
|
||||||
|
"""
|
||||||
|
|
||||||
family = "mixtral"
|
family = "mixtral"
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool:
|
def matches(self, model: nn.Module) -> bool:
|
||||||
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
return (
|
||||||
|
getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
)
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
|
||||||
self._patch_mixtral_forward(moe_layer, shim)
|
|
||||||
|
|
||||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
||||||
yield m
|
yield m
|
||||||
|
|
||||||
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
|
def prepare(
|
||||||
if getattr(moe_layer, "_afb_patched", False):
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
return
|
) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
shim_ref = shim
|
if not self.uses_kernel_routing(moe_layer):
|
||||||
|
self._patch_router(moe_layer)
|
||||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
else:
|
||||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
LOG.info(
|
||||||
if self.training and getattr(self, "jitter_noise", 0) > 0:
|
"AuxFreeMoE: kernel backend detected on %s; "
|
||||||
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
|
"skipping router patch (kernel routing handles bias)",
|
||||||
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
|
type(moe_layer).__name__,
|
||||||
)
|
|
||||||
flat_states = hidden_states.view(-1, hidden_dim)
|
|
||||||
router_logits = self.gate(flat_states)
|
|
||||||
|
|
||||||
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
|
|
||||||
top_k = int(getattr(self, "_afb_top_k", self.top_k))
|
|
||||||
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
|
|
||||||
routing_weights = routing_weights.to(flat_states.dtype)
|
|
||||||
|
|
||||||
flat_idx = selected_experts.reshape(-1)
|
|
||||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
|
||||||
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
|
|
||||||
|
|
||||||
final_hidden_states = torch.zeros(
|
|
||||||
(batch_size * sequence_length, hidden_dim),
|
|
||||||
dtype=flat_states.dtype,
|
|
||||||
device=flat_states.device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
def _patch_router(self, moe_layer: nn.Module) -> None:
|
||||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
"""Patch the TopKRouter to inject aux-free bias into expert selection."""
|
||||||
for expert_idx_tensor in expert_hit:
|
gate = getattr(moe_layer, "gate", None)
|
||||||
expert_idx = int(expert_idx_tensor.squeeze().item())
|
if gate is None:
|
||||||
expert_layer = self.experts[expert_idx]
|
LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
|
||||||
mask = expert_mask[expert_idx].squeeze(0)
|
return
|
||||||
idx, top_x = torch.where(mask)
|
if getattr(gate, "_afb_patched", False):
|
||||||
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
|
return
|
||||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
|
||||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
|
|
||||||
|
|
||||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
# Capture reference to the MoE block for bias / counts access.
|
||||||
return final_hidden_states, router_logits
|
block_ref = moe_layer
|
||||||
|
|
||||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||||
setattr(moe_layer, "_afb_patched", True)
|
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||||
|
router_logits = F.linear(hidden_states, self.weight)
|
||||||
|
router_probs = F.softmax(router_logits.float(), dim=-1)
|
||||||
|
|
||||||
|
# Biased selection, unbiased weights
|
||||||
|
bias = block_ref._afb_bias
|
||||||
|
biased = router_probs + bias
|
||||||
|
_, router_indices = torch.topk(biased, self.top_k, dim=-1)
|
||||||
|
router_scores = torch.gather(router_probs, 1, router_indices)
|
||||||
|
|
||||||
|
# Renormalize (Mixtral always normalizes; Qwen checks config)
|
||||||
|
if getattr(self, "norm_topk_prob", True):
|
||||||
|
router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Accumulate counts for the bias-update callback
|
||||||
|
flat_idx = router_indices.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=self.num_experts)
|
||||||
|
block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
|
||||||
|
|
||||||
|
return router_probs, router_scores, router_indices
|
||||||
|
|
||||||
|
gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||||
|
gate._afb_patched = True
|
||||||
|
moe_layer._afb_patched = True
|
||||||
|
|
||||||
|
|
||||||
class Qwen3Adapter(MixtralAdapter):
|
class Qwen3Adapter(MixtralAdapter):
|
||||||
family = "qwen3_moe"
|
family = "qwen3_moe"
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool:
|
def matches(self, model: nn.Module) -> bool:
|
||||||
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
|
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||||
|
"qwen3_moe",
|
||||||
|
"qwen2_moe",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35MoeAdapter(MixtralAdapter):
|
||||||
|
"""Adapter for Qwen 3.5 MoE models.
|
||||||
|
|
||||||
|
Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
|
||||||
|
is handled by the block forward (untouched by router-level patching).
|
||||||
|
"""
|
||||||
|
|
||||||
|
family = "qwen3_5_moe"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||||
|
"qwen3_5_moe",
|
||||||
|
"qwen3_5_moe_text",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BailingAdapter(BaseMoEAdapter):
|
class BailingAdapter(BaseMoEAdapter):
|
||||||
@@ -207,11 +249,15 @@ class BailingAdapter(BaseMoEAdapter):
|
|||||||
|
|
||||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||||
if hasattr(moe_layer, "num_experts"):
|
if hasattr(moe_layer, "num_experts"):
|
||||||
return int(getattr(moe_layer, "num_experts"))
|
return int(moe_layer.num_experts)
|
||||||
cfg = getattr(moe_layer, "config", None)
|
cfg = getattr(moe_layer, "config", None)
|
||||||
return int(getattr(cfg, "num_experts"))
|
if cfg is None:
|
||||||
|
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||||
|
return int(cfg.num_experts)
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
def prepare(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
self._patch_bailing_gate(moe_layer)
|
self._patch_bailing_gate(moe_layer)
|
||||||
|
|
||||||
@@ -227,7 +273,7 @@ class BailingAdapter(BaseMoEAdapter):
|
|||||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
logits = F.linear(flat.float(), self.weight.float())
|
logits = F.linear(flat.float(), self.weight.float())
|
||||||
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||||
bias = getattr(moe_layer, "_afb_bias")
|
bias = moe_layer._afb_bias
|
||||||
biased_scores = scores_unbiased + bias
|
biased_scores = scores_unbiased + bias
|
||||||
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
||||||
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||||
@@ -238,12 +284,12 @@ class BailingAdapter(BaseMoEAdapter):
|
|||||||
|
|
||||||
flat_topk = topk_idx.reshape(-1)
|
flat_topk = topk_idx.reshape(-1)
|
||||||
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
||||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||||
|
|
||||||
return topk_idx, weights.to(hidden_states.dtype), logits
|
return topk_idx, weights.to(hidden_states.dtype), logits
|
||||||
|
|
||||||
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||||
setattr(gate, "_afb_patched", True)
|
gate._afb_patched = True
|
||||||
|
|
||||||
|
|
||||||
class Llama4Adapter(BaseMoEAdapter):
|
class Llama4Adapter(BaseMoEAdapter):
|
||||||
@@ -257,7 +303,9 @@ class Llama4Adapter(BaseMoEAdapter):
|
|||||||
if m.__class__.__name__ == "Llama4TextMoe":
|
if m.__class__.__name__ == "Llama4TextMoe":
|
||||||
yield m
|
yield m
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
def prepare(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
self._patch_llama4_router(moe_layer)
|
self._patch_llama4_router(moe_layer)
|
||||||
|
|
||||||
@@ -270,9 +318,13 @@ class Llama4Adapter(BaseMoEAdapter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||||
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
|
flat = (
|
||||||
|
hidden_states
|
||||||
|
if hidden_states.dim() == 2
|
||||||
|
else hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
)
|
||||||
router_logits = F.linear(flat, self.weight, self.bias)
|
router_logits = F.linear(flat, self.weight, self.bias)
|
||||||
bias = getattr(moe_layer, "_afb_bias")
|
bias = moe_layer._afb_bias
|
||||||
biased_logits = router_logits + bias
|
biased_logits = router_logits + bias
|
||||||
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
||||||
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
||||||
@@ -281,15 +333,17 @@ class Llama4Adapter(BaseMoEAdapter):
|
|||||||
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||||
|
|
||||||
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
||||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||||
|
|
||||||
return router_scores, router_logits
|
return router_scores, router_logits
|
||||||
|
|
||||||
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
||||||
setattr(router, "_afb_patched", True)
|
router._afb_patched = True
|
||||||
|
|
||||||
|
|
||||||
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
|
def discover_and_prepare_layers(
|
||||||
|
model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
|
||||||
|
) -> list[LayerHandle]:
|
||||||
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||||
|
|
||||||
Returns a list of layer handles for later routing patching and updates.
|
Returns a list of layer handles for later routing patching and updates.
|
||||||
@@ -313,7 +367,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
|
|||||||
try:
|
try:
|
||||||
top_k = adapter.get_top_k(layer)
|
top_k = adapter.get_top_k(layer)
|
||||||
nE = adapter.get_num_experts(layer)
|
nE = adapter.get_num_experts(layer)
|
||||||
except Exception:
|
except (AttributeError, TypeError, ValueError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
||||||
@@ -321,5 +375,7 @@ def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter]
|
|||||||
handles.append(handle)
|
handles.append(handle)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
|
LOG.info(
|
||||||
|
f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing"
|
||||||
|
)
|
||||||
return handles
|
return handles
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional, Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -21,6 +21,7 @@ from .adapters import (
|
|||||||
Llama4Adapter,
|
Llama4Adapter,
|
||||||
MixtralAdapter,
|
MixtralAdapter,
|
||||||
Qwen3Adapter,
|
Qwen3Adapter,
|
||||||
|
Qwen35MoeAdapter,
|
||||||
discover_and_prepare_layers,
|
discover_and_prepare_layers,
|
||||||
)
|
)
|
||||||
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||||
@@ -47,9 +48,11 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
|||||||
# Iterate prepared MoE layers and apply the bias update rule.
|
# Iterate prepared MoE layers and apply the bias update rule.
|
||||||
self.shim.begin_step()
|
self.shim.begin_step()
|
||||||
for layer in self.layer_modules:
|
for layer in self.layer_modules:
|
||||||
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
|
if not hasattr(layer, "_afb_counts") or not hasattr(
|
||||||
|
layer, "_afb_layer_idx"
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
counts = getattr(layer, "_afb_counts")
|
counts = layer._afb_counts
|
||||||
if counts is None:
|
if counts is None:
|
||||||
continue
|
continue
|
||||||
counts = self.shim.all_reduce_counts(counts)
|
counts = self.shim.all_reduce_counts(counts)
|
||||||
@@ -57,7 +60,7 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
|||||||
if layer_idx is None:
|
if layer_idx is None:
|
||||||
counts.zero_()
|
counts.zero_()
|
||||||
continue
|
continue
|
||||||
bias = getattr(layer, "_afb_bias")
|
bias = layer._afb_bias
|
||||||
counts_for_update = counts.to(bias.device)
|
counts_for_update = counts.to(bias.device)
|
||||||
tokens_seen = int(counts_for_update.sum().item())
|
tokens_seen = int(counts_for_update.sum().item())
|
||||||
# local layer-state EMA and bias update
|
# local layer-state EMA and bias update
|
||||||
@@ -143,12 +146,16 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Be conservative — skip known native aux-free families
|
# Be conservative — skip known native aux-free families
|
||||||
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
|
native_auxfree = getattr(
|
||||||
|
getattr(model, "config", object()), "model_type", ""
|
||||||
|
) in (
|
||||||
"deepseek_v3",
|
"deepseek_v3",
|
||||||
"glm4_moe",
|
"glm4_moe",
|
||||||
)
|
)
|
||||||
if native_auxfree:
|
if native_auxfree:
|
||||||
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
|
LOG.info(
|
||||||
|
"AuxFreeMoE: model reports native aux-free routing; skipping patching"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Build aux-free state and shim
|
# Build aux-free state and shim
|
||||||
@@ -171,6 +178,7 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
adapters: list[BaseMoEAdapter] = [
|
adapters: list[BaseMoEAdapter] = [
|
||||||
MixtralAdapter(),
|
MixtralAdapter(),
|
||||||
Qwen3Adapter(),
|
Qwen3Adapter(),
|
||||||
|
Qwen35MoeAdapter(),
|
||||||
BailingAdapter(),
|
BailingAdapter(),
|
||||||
Llama4Adapter(),
|
Llama4Adapter(),
|
||||||
]
|
]
|
||||||
@@ -178,7 +186,7 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||||
n_layers = 0
|
n_layers = 0
|
||||||
n_experts = None
|
n_experts = None
|
||||||
for m in model.modules():
|
for _m in model.modules():
|
||||||
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
||||||
device = next(model.parameters(), torch.tensor(0)).device
|
device = next(model.parameters(), torch.tensor(0)).device
|
||||||
if n_layers <= 0:
|
if n_layers <= 0:
|
||||||
@@ -186,7 +194,9 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
if n_experts is None:
|
if n_experts is None:
|
||||||
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||||
n_experts = 2
|
n_experts = 2
|
||||||
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
|
state = AuxFreeState(
|
||||||
|
num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
|
||||||
|
)
|
||||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||||
ep_group = None
|
ep_group = None
|
||||||
if sync_group == "ep":
|
if sync_group == "ep":
|
||||||
@@ -207,11 +217,15 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
|
|
||||||
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
||||||
if not dist.is_available() or not dist.is_initialized():
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||||
if not ep_size or ep_size <= 1:
|
if not ep_size or ep_size <= 1:
|
||||||
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
world = dist.get_world_size()
|
world = dist.get_world_size()
|
||||||
if world % ep_size != 0:
|
if world % ep_size != 0:
|
||||||
|
|||||||
@@ -240,7 +240,16 @@ def _softmax_topk_route(
|
|||||||
|
|
||||||
top_k = base_gate.top_k
|
top_k = base_gate.top_k
|
||||||
num_experts = base_gate.num_experts
|
num_experts = base_gate.num_experts
|
||||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
|
||||||
|
# Aux-free bias: biased selection, unbiased weights
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = routing_weights + afb_bias
|
||||||
|
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
|
||||||
|
routing_weights = routing_weights.gather(1, selected_experts)
|
||||||
|
_accumulate_afb_counts(moe_block, selected_experts)
|
||||||
|
else:
|
||||||
|
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||||
|
|
||||||
if getattr(base_gate, "norm_topk_prob", True):
|
if getattr(base_gate, "norm_topk_prob", True):
|
||||||
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
||||||
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
|
|||||||
else:
|
else:
|
||||||
scores_for_choice = router_probs
|
scores_for_choice = router_probs
|
||||||
|
|
||||||
|
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = scores_for_choice + afb_bias
|
||||||
|
|
||||||
# Group-based selection: pick top groups, mask the rest
|
# Group-based selection: pick top groups, mask the rest
|
||||||
n_group = getattr(moe_block, "n_group", 1)
|
n_group = getattr(moe_block, "n_group", 1)
|
||||||
if n_group > 1:
|
if n_group > 1:
|
||||||
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
|
|||||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Accumulate counts for aux-free bias update
|
||||||
|
if afb_bias is not None:
|
||||||
|
_accumulate_afb_counts(moe_block, topk_indices)
|
||||||
|
|
||||||
# Optional renormalization + scaling
|
# Optional renormalization + scaling
|
||||||
if getattr(moe_block, "norm_topk_prob", True):
|
if getattr(moe_block, "norm_topk_prob", True):
|
||||||
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
||||||
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||||
|
"""Accumulate per-expert token counts for aux-free bias updates."""
|
||||||
|
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||||
|
if afb_counts is None:
|
||||||
|
return
|
||||||
|
flat_idx = topk_indices.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
|
||||||
|
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Shared expert helpers
|
# Shared expert helpers
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
|
|||||||
|
|
||||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||||
|
|
||||||
|
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
|
||||||
|
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
|
||||||
|
as buffers. The routing functions transparently inject the bias into expert
|
||||||
|
*selection* (biased topk) while keeping mixture *weights* from unbiased
|
||||||
|
scores, then accumulate per-expert token counts for the post-step bias update.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -101,17 +107,25 @@ def softmax_topk_routing(
|
|||||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||||
|
|
||||||
|
# Aux-free bias: biased selection, unbiased weights
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
scores_for_choice = router_probs
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = router_probs + afb_bias
|
||||||
|
|
||||||
# Select top-k experts per token
|
# Select top-k experts per token
|
||||||
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
|
top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
|
||||||
|
|
||||||
|
# When aux-free bias is active, gather unbiased weights and accumulate counts
|
||||||
|
if afb_bias is not None:
|
||||||
|
top_values = router_probs.gather(1, top_indices)
|
||||||
|
_accumulate_afb_counts(moe_block, top_indices)
|
||||||
|
|
||||||
# Renormalize if configured (default True for models without the attribute,
|
# Renormalize if configured (default True for models without the attribute,
|
||||||
# e.g. Mixtral/MiniMax which always normalize)
|
# e.g. Mixtral/MiniMax which always normalize)
|
||||||
if getattr(gate, "norm_topk_prob", True):
|
if getattr(gate, "norm_topk_prob", True):
|
||||||
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
# no-op: matches transformers which casts to softmax output dtype (float32).
|
|
||||||
# top_values = top_values.to(router_probs.dtype)
|
|
||||||
|
|
||||||
# Flatten for moe_general_routing_inputs.
|
# Flatten for moe_general_routing_inputs.
|
||||||
# Token indices are naturally sorted ascending from the [T, K] layout:
|
# Token indices are naturally sorted ascending from the [T, K] layout:
|
||||||
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
||||||
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
|
|||||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||||
|
|
||||||
|
# Aux-free bias: inject before group selection / topk
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
scores_for_choice = router_probs
|
scores_for_choice = router_probs
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = router_probs + afb_bias
|
||||||
|
|
||||||
# Group selection: pick top groups, mask the rest
|
# Group selection: pick top groups, mask the rest
|
||||||
if n_group > 1:
|
if n_group > 1:
|
||||||
@@ -164,6 +182,10 @@ def softmax_group_topk_routing(
|
|||||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Accumulate counts for aux-free bias update
|
||||||
|
if afb_bias is not None:
|
||||||
|
_accumulate_afb_counts(moe_block, topk_indices)
|
||||||
|
|
||||||
# Renormalization + scaling
|
# Renormalization + scaling
|
||||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||||
if norm_topk_prob:
|
if norm_topk_prob:
|
||||||
@@ -233,6 +255,11 @@ def sigmoid_topk_routing(
|
|||||||
)
|
)
|
||||||
scores_for_choice = router_probs + e_score_correction_bias
|
scores_for_choice = router_probs + e_score_correction_bias
|
||||||
|
|
||||||
|
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = scores_for_choice + afb_bias
|
||||||
|
|
||||||
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
||||||
if n_group > 1:
|
if n_group > 1:
|
||||||
group_scores = (
|
group_scores = (
|
||||||
@@ -256,6 +283,10 @@ def sigmoid_topk_routing(
|
|||||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Accumulate counts for aux-free bias update
|
||||||
|
if afb_bias is not None:
|
||||||
|
_accumulate_afb_counts(moe_block, topk_indices)
|
||||||
|
|
||||||
# Optional renormalization + scaling
|
# Optional renormalization + scaling
|
||||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||||
if norm_topk_prob:
|
if norm_topk_prob:
|
||||||
@@ -276,3 +307,19 @@ def sigmoid_topk_routing(
|
|||||||
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||||
|
|
||||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||||
|
"""Accumulate per-expert token counts for the aux-free bias update.
|
||||||
|
|
||||||
|
Called when ``moe_block._afb_bias`` is present (registered by the
|
||||||
|
``aux_free_router`` plugin). The counts are later consumed by the
|
||||||
|
``MoeAuxFreeBiasUpdateCallback`` at each training step.
|
||||||
|
"""
|
||||||
|
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||||
|
if afb_counts is None:
|
||||||
|
return
|
||||||
|
num_experts = afb_counts.numel()
|
||||||
|
flat_idx = topk_indices.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=num_experts)
|
||||||
|
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from importlib import util as importlib_util
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from importlib import util as importlib_util
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||||
@@ -96,16 +95,21 @@ def _build_bailing_model():
|
|||||||
|
|
||||||
|
|
||||||
def _build_llama4_model():
|
def _build_llama4_model():
|
||||||
from transformers import Llama4TextConfig
|
|
||||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||||
|
|
||||||
config = Llama4TextConfig(
|
# Build config without __post_init__ validation (works around a
|
||||||
|
# huggingface_hub strict-dataclass type mismatch for layer_types).
|
||||||
|
config = object.__new__(__import__("transformers").Llama4TextConfig)
|
||||||
|
config.__dict__.update(
|
||||||
hidden_size=16,
|
hidden_size=16,
|
||||||
intermediate_size=32,
|
intermediate_size=32,
|
||||||
num_local_experts=4,
|
num_local_experts=4,
|
||||||
num_attention_heads=2,
|
num_attention_heads=2,
|
||||||
num_key_value_heads=2,
|
num_key_value_heads=2,
|
||||||
num_experts_per_tok=2,
|
num_experts_per_tok=2,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
hidden_act="silu",
|
||||||
|
layer_types=None,
|
||||||
)
|
)
|
||||||
layer = Llama4TextMoe(config)
|
layer = Llama4TextMoe(config)
|
||||||
|
|
||||||
@@ -148,6 +152,38 @@ def _build_mixtral_model():
|
|||||||
return DummyModel(layer), layer
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
|
def _build_qwen35_moe_model():
|
||||||
|
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
|
||||||
|
Qwen3_5MoeTextConfig,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||||
|
Qwen3_5MoeSparseMoeBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = Qwen3_5MoeTextConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
moe_intermediate_size=32,
|
||||||
|
shared_expert_intermediate_size=32,
|
||||||
|
num_experts=4,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
)
|
||||||
|
layer = Qwen3_5MoeSparseMoeBlock(config)
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, moe_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.moe = moe_layer
|
||||||
|
self.config = SimpleNamespace(model_type="qwen3_5_moe")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.moe(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||||
if args is None:
|
if args is None:
|
||||||
args = SimpleNamespace(logging_steps=1)
|
args = SimpleNamespace(logging_steps=1)
|
||||||
@@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
|
|
||||||
_run_callback(plugin, cfg)
|
_run_callback(plugin, cfg)
|
||||||
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||||
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
|
self.assertFalse(
|
||||||
|
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
|
||||||
|
)
|
||||||
|
|
||||||
def test_llama4_adapter_biases_router_selection(self):
|
def test_llama4_adapter_biases_router_selection(self):
|
||||||
model, layer = _build_llama4_model()
|
model, layer = _build_llama4_model()
|
||||||
@@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
|
|
||||||
_run_callback(plugin, cfg)
|
_run_callback(plugin, cfg)
|
||||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
|
self.assertFalse(
|
||||||
|
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||||
|
)
|
||||||
|
|
||||||
def test_bias_warmup_respected(self):
|
def test_bias_warmup_respected(self):
|
||||||
model, block = _build_bailing_model()
|
model, block = _build_bailing_model()
|
||||||
@@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
|
|
||||||
# Warmup steps should leave bias untouched.
|
# Warmup steps should leave bias untouched.
|
||||||
_step()
|
_step()
|
||||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
self.assertTrue(
|
||||||
|
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||||
|
)
|
||||||
|
|
||||||
_step()
|
_step()
|
||||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
self.assertTrue(
|
||||||
|
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||||
|
)
|
||||||
|
|
||||||
# Third step exceeds warmup -> bias should update.
|
# Third step exceeds warmup -> bias should update.
|
||||||
_step()
|
_step()
|
||||||
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||||
|
|
||||||
def test_mixtral_adapter_respects_native_forward(self):
|
def test_mixtral_adapter_patches_router_not_forward(self):
|
||||||
|
"""Verify that aux-free patches the router (gate) only, and the
|
||||||
|
v5 block forward signature (single tensor return) is preserved."""
|
||||||
model, layer = _build_mixtral_model()
|
model, layer = _build_mixtral_model()
|
||||||
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
|
|
||||||
|
|
||||||
hidden_dim = layer.config.hidden_size
|
|
||||||
hidden = torch.randn(2, 3, hidden_dim)
|
|
||||||
baseline_out, baseline_logits = layer(hidden.clone())
|
|
||||||
|
|
||||||
cfg = _cfg()
|
cfg = _cfg()
|
||||||
plugin = AuxFreeMoEPlugin()
|
plugin = AuxFreeMoEPlugin()
|
||||||
plugin.post_model_build(cfg, model)
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
patched_out, patched_logits = layer(hidden.clone())
|
# Gate should be patched, not the block forward
|
||||||
self.assertTrue(torch.allclose(baseline_out, patched_out))
|
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||||
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
|
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||||
|
|
||||||
|
# v5 block forward returns a single tensor (not a tuple with logits)
|
||||||
|
hidden = torch.randn(2, 3, layer.config.hidden_size)
|
||||||
|
out = layer(hidden)
|
||||||
|
self.assertIsInstance(out, torch.Tensor)
|
||||||
|
self.assertEqual(out.shape, hidden.shape)
|
||||||
|
|
||||||
|
# Counts should have been accumulated
|
||||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
_run_callback(plugin, cfg)
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
|
def test_mixtral_adapter_bias_affects_selection(self):
|
||||||
|
"""When bias is large for one expert, it should be selected more often."""
|
||||||
|
model, layer = _build_mixtral_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
# Set a large bias for expert 0 to force its selection
|
||||||
|
layer._afb_bias.zero_()
|
||||||
|
layer._afb_bias[0] = 10.0
|
||||||
|
|
||||||
|
hidden = torch.randn(2, 8, layer.config.hidden_size)
|
||||||
|
num_tokens = 2 * 8 # batch * seq
|
||||||
|
layer(hidden)
|
||||||
|
|
||||||
|
# With top_k=2, expert 0 should appear in every token's selection
|
||||||
|
# (once per token = num_tokens counts, not num_tokens * top_k)
|
||||||
|
counts = layer._afb_counts.clone()
|
||||||
|
self.assertEqual(
|
||||||
|
int(counts[0].item()),
|
||||||
|
num_tokens,
|
||||||
|
msg="Expert 0 should be selected for every token when heavily biased",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
|
||||||
|
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
|
||||||
|
output includes shared expert contribution."""
|
||||||
|
model, layer = _build_qwen35_moe_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
# Gate should be patched
|
||||||
|
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||||
|
# Shared expert should be unmodified
|
||||||
|
self.assertTrue(hasattr(layer, "shared_expert"))
|
||||||
|
self.assertTrue(hasattr(layer, "shared_expert_gate"))
|
||||||
|
|
||||||
|
# Forward should return a single tensor (shared + routed)
|
||||||
|
hidden_size = layer.gate.hidden_dim
|
||||||
|
hidden = torch.randn(2, 3, hidden_size)
|
||||||
|
out = layer(hidden)
|
||||||
|
self.assertIsInstance(out, torch.Tensor)
|
||||||
|
self.assertEqual(out.shape, hidden.shape)
|
||||||
|
|
||||||
|
# Counts should have been accumulated
|
||||||
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
|
||||||
|
def test_qwen35_moe_adapter_bias_updates(self):
|
||||||
|
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
|
||||||
|
model, layer = _build_qwen35_moe_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
hidden_size = layer.gate.hidden_dim
|
||||||
|
hidden = torch.randn(2, 4, hidden_size)
|
||||||
|
layer(hidden)
|
||||||
|
|
||||||
|
# Bias should start at zero
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
|
# After callback: counts reset, EMA updated, bias updated
|
||||||
|
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_qwen35_moe_adapter_model_type_matching(self):
|
||||||
|
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
|
||||||
|
|
||||||
|
adapter = Qwen35MoeAdapter()
|
||||||
|
|
||||||
|
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
|
||||||
|
model_text = SimpleNamespace(
|
||||||
|
config=SimpleNamespace(model_type="qwen3_5_moe_text")
|
||||||
|
)
|
||||||
|
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
|
||||||
|
|
||||||
|
self.assertTrue(adapter.matches(model_moe))
|
||||||
|
self.assertTrue(adapter.matches(model_text))
|
||||||
|
self.assertFalse(adapter.matches(model_other))
|
||||||
|
|
||||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
@@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||||
tmp_init.close()
|
tmp_init.close()
|
||||||
init_method = f"file://{tmp_init.name}"
|
init_method = f"file://{tmp_init.name}"
|
||||||
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
|
dist.init_process_group(
|
||||||
|
backend="gloo", init_method=init_method, world_size=1, rank=0
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
block(hidden)
|
block(hidden)
|
||||||
@@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
|
|
||||||
def test_telemetry_logging(self):
|
def test_telemetry_logging(self):
|
||||||
model, layer = _build_mixtral_model()
|
model, layer = _build_mixtral_model()
|
||||||
layer.jitter_noise = 0.0
|
|
||||||
cfg = _cfg()
|
cfg = _cfg()
|
||||||
plugin = AuxFreeMoEPlugin()
|
plugin = AuxFreeMoEPlugin()
|
||||||
plugin.post_model_build(cfg, model)
|
plugin.post_model_build(cfg, model)
|
||||||
@@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
self.assertIn("moe_afb/l0_load_max", telemetry)
|
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||||
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||||
|
|
||||||
|
def test_get_num_experts_v5_attribute_paths(self):
|
||||||
|
"""Verify get_num_experts works with v5 attribute layout where
|
||||||
|
num_experts is on gate/experts sub-modules, not the block."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||||
|
|
||||||
|
adapter = MixtralAdapter()
|
||||||
|
|
||||||
|
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
|
||||||
|
block = SimpleNamespace(
|
||||||
|
gate=SimpleNamespace(num_experts=8),
|
||||||
|
experts=SimpleNamespace(num_experts=8),
|
||||||
|
)
|
||||||
|
self.assertEqual(adapter.get_num_experts(block), 8)
|
||||||
|
|
||||||
|
# Also works when num_experts is directly on block
|
||||||
|
block2 = SimpleNamespace(num_experts=4)
|
||||||
|
self.assertEqual(adapter.get_num_experts(block2), 4)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxFreeKernelComposition(unittest.TestCase):
|
||||||
|
"""Tests that aux-free bias composes correctly with kernel routing."""
|
||||||
|
|
||||||
|
def test_sonicmoe_softmax_routing_with_afb_bias(self):
|
||||||
|
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||||
|
|
||||||
|
num_experts = 4
|
||||||
|
top_k = 2
|
||||||
|
hidden_dim = 16
|
||||||
|
T = 6
|
||||||
|
|
||||||
|
# Build a mock MoE block with gate attributes
|
||||||
|
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||||
|
gate.top_k = top_k
|
||||||
|
gate.num_experts = num_experts
|
||||||
|
gate.norm_topk_prob = True
|
||||||
|
|
||||||
|
moe_block = SimpleNamespace(gate=gate)
|
||||||
|
hidden = torch.randn(T, hidden_dim)
|
||||||
|
|
||||||
|
# Baseline: no bias
|
||||||
|
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
|
||||||
|
hidden, moe_block
|
||||||
|
)
|
||||||
|
self.assertEqual(scores_base.shape[0], T * top_k)
|
||||||
|
|
||||||
|
# Now register aux-free buffers and set heavy bias on expert 0
|
||||||
|
moe_block._afb_bias = torch.zeros(num_experts)
|
||||||
|
moe_block._afb_bias[0] = 100.0
|
||||||
|
moe_block._afb_counts = torch.zeros(num_experts)
|
||||||
|
|
||||||
|
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
|
||||||
|
hidden, moe_block
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expert 0 should be selected for every token
|
||||||
|
self.assertTrue(
|
||||||
|
(exp_biased == 0).any(),
|
||||||
|
"Expert 0 should appear in selections when heavily biased",
|
||||||
|
)
|
||||||
|
# Counts should have been accumulated
|
||||||
|
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||||
|
# Total counts should equal T * top_k
|
||||||
|
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||||
|
|
||||||
|
def test_sonicmoe_routing_without_bias_unchanged(self):
|
||||||
|
"""Without _afb_bias, routing should produce identical results."""
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||||
|
|
||||||
|
num_experts = 4
|
||||||
|
top_k = 2
|
||||||
|
hidden_dim = 16
|
||||||
|
|
||||||
|
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||||
|
gate.top_k = top_k
|
||||||
|
gate.num_experts = num_experts
|
||||||
|
gate.norm_topk_prob = True
|
||||||
|
|
||||||
|
moe_block = SimpleNamespace(gate=gate)
|
||||||
|
hidden = torch.randn(4, hidden_dim)
|
||||||
|
|
||||||
|
# Without _afb_bias attribute
|
||||||
|
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
# With _afb_bias = zeros (should be equivalent)
|
||||||
|
moe_block._afb_bias = torch.zeros(num_experts)
|
||||||
|
moe_block._afb_counts = torch.zeros(num_experts)
|
||||||
|
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
torch.testing.assert_close(scores1, scores2)
|
||||||
|
torch.testing.assert_close(exp1, exp2)
|
||||||
|
|
||||||
|
@unittest.skipUnless(
|
||||||
|
importlib_util.find_spec("triton") is not None,
|
||||||
|
"triton not installed (required by scattermoe)",
|
||||||
|
)
|
||||||
|
def test_scattermoe_softmax_routing_with_afb_bias(self):
|
||||||
|
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||||
|
_softmax_topk_route,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_experts = 4
|
||||||
|
top_k = 2
|
||||||
|
hidden_dim = 16
|
||||||
|
T = 6
|
||||||
|
|
||||||
|
gate_weight = torch.randn(num_experts, hidden_dim)
|
||||||
|
base_gate = SimpleNamespace(
|
||||||
|
top_k=top_k,
|
||||||
|
num_experts=num_experts,
|
||||||
|
norm_topk_prob=True,
|
||||||
|
weight=gate_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_block = SimpleNamespace()
|
||||||
|
hidden = torch.randn(T, hidden_dim)
|
||||||
|
|
||||||
|
# Baseline without bias
|
||||||
|
w_base, e_base, _, _ = _softmax_topk_route(
|
||||||
|
moe_block, base_gate, hidden, gate_weight, None
|
||||||
|
)
|
||||||
|
|
||||||
|
# With heavy bias on expert 0
|
||||||
|
moe_block._afb_bias = torch.zeros(num_experts)
|
||||||
|
moe_block._afb_bias[0] = 100.0
|
||||||
|
moe_block._afb_counts = torch.zeros(num_experts)
|
||||||
|
|
||||||
|
w_biased, e_biased, _, _ = _softmax_topk_route(
|
||||||
|
moe_block, base_gate, hidden, gate_weight, None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expert 0 should appear in all selections
|
||||||
|
self.assertTrue((e_biased == 0).any())
|
||||||
|
# Counts accumulated
|
||||||
|
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||||
|
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||||
|
|
||||||
|
def test_kernel_routing_skips_router_patch(self):
|
||||||
|
"""When a kernel backend has patched the block class, the adapter
|
||||||
|
should skip patching the router (buffers are still registered)."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||||
|
|
||||||
|
adapter = MixtralAdapter()
|
||||||
|
|
||||||
|
# Create a mock layer whose class has _original_forward (SonicMoE marker)
|
||||||
|
class PatchedBlock(nn.Module):
|
||||||
|
_original_forward = True # SonicMoE marker
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Linear(16, 4, bias=False)
|
||||||
|
self.gate.top_k = 2
|
||||||
|
self.gate.num_experts = 4
|
||||||
|
self.gate.hidden_dim = 16
|
||||||
|
self.experts = nn.Linear(16, 16) # placeholder
|
||||||
|
|
||||||
|
layer = PatchedBlock()
|
||||||
|
self.assertTrue(adapter.uses_kernel_routing(layer))
|
||||||
|
|
||||||
|
# Gate should NOT be patched (kernel handles routing)
|
||||||
|
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
|
||||||
|
def test_adapter_buffers_registered_even_with_kernel(self):
|
||||||
|
"""Even when kernel routing is active, aux-free buffers must be
|
||||||
|
registered on the MoE block so the kernel routing can find them."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import (
|
||||||
|
LayerHandle,
|
||||||
|
MixtralAdapter,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.aux_free_router.core import (
|
||||||
|
AuxFreeConfig,
|
||||||
|
AuxFreeShim,
|
||||||
|
AuxFreeState,
|
||||||
|
)
|
||||||
|
|
||||||
|
class PatchedBlock(nn.Module):
|
||||||
|
_original_forward = True
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Linear(16, 4, bias=False)
|
||||||
|
self.gate.top_k = 2
|
||||||
|
self.gate.num_experts = 4
|
||||||
|
self.gate.hidden_dim = 16
|
||||||
|
self.experts = nn.Linear(16, 16)
|
||||||
|
|
||||||
|
layer = PatchedBlock()
|
||||||
|
adapter = MixtralAdapter()
|
||||||
|
cfg = AuxFreeConfig()
|
||||||
|
state = AuxFreeState(
|
||||||
|
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
|
||||||
|
)
|
||||||
|
shim = AuxFreeShim(state=state)
|
||||||
|
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
|
||||||
|
|
||||||
|
adapter.prepare(layer, handle, shim)
|
||||||
|
|
||||||
|
# Buffers should be registered for kernel routing to use
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_counts"))
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_ema"))
|
||||||
|
# But gate should NOT be patched
|
||||||
|
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user