aux_free_router: sync shim state

- drive warmup-aware bias updates and register live buffers
This commit is contained in:
lhl
2025-10-28 08:08:00 +00:00
committed by Wing Lian
parent 2af7475fdf
commit a0019021dd
3 changed files with 153 additions and 27 deletions

View File

@@ -50,7 +50,7 @@ class BaseMoEAdapter:
except Exception: # pragma: no cover - non-critical
pass
# NOTE(review): this is a unified-diff view without +/- markers — the first
# `def` line below is the pre-change signature, the second is the post-change
# signature that additionally threads the AuxFreeShim through.
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle) -> None:
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
# Device of the layer's first parameter; `torch.tensor(0)` is the default for
# parameter-less layers, making them fall back to that tensor's (CPU) device.
device = next(moe_layer.parameters(), torch.tensor(0)).device
# Idempotent: register the per-expert routing-bias buffer only once.
if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
@@ -60,10 +60,11 @@ class BaseMoEAdapter:
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
# Stash routing metadata directly on the module so the patched forward can
# read it without holding a reference to the handle.
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
# New in this commit: hand the live buffers to the shim so bias/EMA updates
# operate on the same tensors the patched forward uses.
shim.register_layer_buffers(handle.layer_idx, moe_layer)
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
"""Attach per-layer buffers and mark as aux-free enabled."""
# NOTE(review): diff artifact — the first call is the pre-change form, the
# second is the post-change form that forwards `shim`.
self._register_aux_buffers(moe_layer, handle)
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_forward_with_aux_free(moe_layer)
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
@@ -122,11 +123,60 @@ class MixtralAdapter(BaseMoEAdapter):
def matches(self, model: nn.Module) -> bool:
# Adapter applies only when config.model_type == "mixtral" (missing
# config/model_type attributes safely compare unequal).
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
# Register buffers AND bind them to the shim, then swap in the aux-free forward.
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_mixtral_forward(moe_layer, shim)
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
# Yields every module whose class name ends in "SparseMoeBlock".
for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock"):
yield m
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
# Guard: never double-patch the same block.
if getattr(moe_layer, "_afb_patched", False):
return
# Close over the shim so the replacement forward can reach it.
shim_ref = shim
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
# assumes hidden_states is (batch, seq, hidden) — the unpack below fixes rank 3
batch_size, sequence_length, hidden_dim = hidden_states.shape
# Multiplicative jitter on inputs during training when jitter_noise > 0.
if self.training and getattr(self, "jitter_noise", 0) > 0:
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)
flat_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(flat_states)
# Metadata stashed by _register_aux_buffers; defaults keep this safe if absent.
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
top_k = int(getattr(self, "_afb_top_k", self.top_k))
# Aux-free routing: the shim picks experts from bias-shifted logits but
# returns mixing weights computed from the raw logits.
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
routing_weights = routing_weights.to(flat_states.dtype)
# Accumulate per-expert token counts into the live buffer; the step-end
# callback consumes and zeroes them.
flat_idx = selected_experts.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=flat_states.dtype,
device=flat_states.device,
)
# Dispatch loop: one-hot over selections, permuted to (expert, slot, token);
# visit only experts that received at least one token. This appears to
# mirror the upstream HF Mixtral SparseMoeBlock pattern — confirm upstream.
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# index_add_ scatters each expert's weighted outputs back to token positions.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
# Return both the mixed output and the raw router logits.
return final_hidden_states, router_logits
# Bind the replacement as a bound method on this instance and mark patched.
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
setattr(moe_layer, "_afb_patched", True)
# Qwen3-MoE reuses the Mixtral SparseMoeBlock patching path unchanged; only
# the adapter family tag differs.
class Qwen3Adapter(MixtralAdapter):
family = "qwen3_moe"
@@ -154,7 +204,7 @@ class BailingAdapter(BaseMoEAdapter):
return int(getattr(cfg, "num_experts"))
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
# NOTE(review): diff artifact — pre-change call (no shim) followed by the
# post-change call that binds the live buffers to the shim.
self._register_aux_buffers(moe_layer, handle)
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_bailing_gate(moe_layer)
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
@@ -200,7 +250,7 @@ class Llama4Adapter(BaseMoEAdapter):
yield m
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
# NOTE(review): diff artifact — pre-change call (no shim) followed by the
# post-change call that binds the live buffers to the shim.
self._register_aux_buffers(moe_layer, handle)
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_llama4_router(moe_layer)
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:

View File

@@ -5,6 +5,9 @@ from typing import Optional
import torch
import torch.distributed as dist
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@dataclass
@@ -29,22 +32,52 @@ class AuxFreeState:
class AuxFreeShim:
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
# NOTE(review): diff view — the single-line __init__ signature below is the
# pre-change form; the multi-line signature that follows adds `ep_size`.
def __init__(self, state: AuxFreeState, ep_group: Optional[dist.ProcessGroup] = None):
def __init__(
self,
state: AuxFreeState,
ep_group: Optional[dist.ProcessGroup] = None,
ep_size: Optional[int] = None,
):
self.state = state
self.ep_group = ep_group
self._ep_size = ep_size
# Defer EP-group construction when 'ep' sync is requested but no group was
# supplied (torch.distributed may not be initialized yet at plugin setup).
self._ep_group_pending = (
self.state.cfg.sync_group == "ep" and self.ep_group is None
)
# Maps layer_idx -> patched MoE module whose _afb_* buffers we mirror.
self._layer_modules: dict[int, torch.nn.Module] = {}
@torch.no_grad()
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
# NOTE(review): diff artifact — the unconditional read below is the
# pre-change line; the if/else that follows is its replacement, which
# prefers the module's live _afb_bias buffer once the layer is registered.
b = self.state.bias[layer_idx]
module = self._layer_modules.get(layer_idx)
if module is not None and hasattr(module, "_afb_bias"):
b = getattr(module, "_afb_bias")
else:
b = self.state.bias[layer_idx]
# Bias shifts only WHICH experts are picked...
biased = logits + b # bias is a buffer
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
# ...while mixing weights are softmaxed over the UNbiased logits of the
# chosen experts, keeping the layer output itself unbiased.
chosen_logits = torch.gather(logits, -1, topk_idx)
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
return topk_idx, weights
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
    """Bind a patched layer's live buffers into the shim's bookkeeping.

    After this call, lookups for ``layer_idx`` (``select_experts`` /
    ``update_bias``) resolve to the module's own ``_afb_bias`` / ``_afb_ema``
    buffers, so shim state and model state cannot drift apart.

    Args:
        layer_idx: index of the MoE layer being registered.
        module: the patched MoE module carrying ``_afb_bias``/``_afb_ema``
            buffers (raises AttributeError if they are missing).
    """
    # Fix for the diff-flattened original: reconstructed with valid indentation.
    self._layer_modules[layer_idx] = module
    layer_bias = module._afb_bias
    layer_ema = module._afb_ema
    # Re-point the shim-owned slots at the very same tensors, guarded so an
    # out-of-range layer index never raises.
    if layer_idx < len(self.state.bias):
        self.state.bias[layer_idx] = layer_bias
    if layer_idx < len(self.state.ema_load):
        self.state.ema_load[layer_idx] = layer_ema
def begin_step(self) -> None:
    """Advance the shared optimizer-step counter.

    Must be called exactly once per optimizer step, before any per-layer
    ``update_bias`` calls, so warmup gating sees a consistent step count.
    """
    # Fix for the diff-flattened original: reconstructed with valid indentation.
    self.state.steps += 1
@torch.no_grad()
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
self._maybe_init_ep_group()
if not dist.is_available() or not dist.is_initialized():
return counts
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
@@ -55,22 +88,63 @@ class AuxFreeShim:
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
cfg = self.state.cfg
# NOTE(review): diff artifact — this increment is the pre-change line; the
# new flow increments once per step via begin_step() (see the callback hunk).
self.state.steps += 1
# Skip updates during warmup so early, noisy counts don't move the bias.
if self.state.steps <= cfg.warmup_steps:
return
nE = step_counts.numel()
if tokens_seen <= 0:
return
# NOTE(review): pre-change freq computation; superseded by the
# device-aligned `counts`/`freq` pair further down.
freq = step_counts.float() / float(tokens_seen)
ema = self.state.ema_load[layer_idx]
# Prefer the module's live buffers once the layer is registered, so updates
# land on the same tensors the patched forward reads.
module = self._layer_modules.get(layer_idx)
if module is not None and hasattr(module, "_afb_ema"):
ema = getattr(module, "_afb_ema")
bias = getattr(module, "_afb_bias")
else:
ema = self.state.ema_load[layer_idx]
bias = self.state.bias[layer_idx]
# Move counts onto the EMA's device before folding them in.
counts = step_counts.to(ema.device)
freq = counts.float() / float(tokens_seen)
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
# Nudge the bias toward the uniform load target 1/nE at rate cfg.rate.
target = 1.0 / float(nE)
delta = cfg.rate * (target - ema)
# optional mean-centering to keep sum(bias) ~ 0
delta = delta - delta.mean()
# NOTE(review): pre-change rebinding — if still present post-merge it would
# defeat the module-buffer path selected above; confirm it was removed.
bias = self.state.bias[layer_idx]
bias.add_(delta)
if cfg.bias_cap is not None and cfg.bias_cap > 0:
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
def _maybe_init_ep_group(self) -> None:
    """Lazily resolve the expert-parallel reduction group.

    Runs only while ``_ep_group_pending`` is set (i.e. sync_group == "ep" was
    requested before torch.distributed came up). Falls back to the WORLD
    group — with a warning — when ``ep_size`` is unset/<=1 or does not evenly
    divide the world size; otherwise selects/builds the contiguous-rank EP
    group containing this rank. Clears ``_ep_group_pending`` once resolved.
    """
    # Fix for the diff-flattened original: reconstructed with valid indentation.
    if not self._ep_group_pending:
        return
    if not (dist.is_available() and dist.is_initialized()):
        # Distributed runtime still not up; stay pending and retry later.
        return
    ep_size = self._ep_size
    if not ep_size or ep_size <= 1:
        LOG.warning(
            "AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
        )
        self.ep_group = dist.group.WORLD
        self._ep_group_pending = False
        return
    world_size = dist.get_world_size()
    if world_size % ep_size != 0:
        LOG.warning(
            "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
            ep_size,
            world_size,
        )
        self.ep_group = dist.group.WORLD
        self._ep_group_pending = False
        return
    if ep_size == world_size:
        # EP spans every rank — reuse WORLD instead of allocating a new group.
        self.ep_group = dist.group.WORLD
    else:
        # NOTE(review): dist.new_group is collective — every rank must call it
        # for EVERY group, not just its own block; as written each rank only
        # creates its own block's group, which can hang with >1 block. Confirm
        # against the per-rank call pattern in the trainer.
        start_rank = (dist.get_rank() // ep_size) * ep_size
        self.ep_group = dist.new_group(tuple(range(start_rank, start_rank + ep_size)))
    LOG.info(
        "AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
        ep_size,
        dist.get_world_size(),
    )
    self._ep_group_pending = False

View File

@@ -42,29 +42,23 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
# Iterate prepared MoE layers and apply the bias update rule.
cfg = self.shim.state.cfg
# New in this commit: advance the shared step counter exactly once per
# optimizer step (warmup gating now lives in the shim).
self.shim.begin_step()
for layer in self.layer_modules:
# Only layers prepared by an adapter carry the _afb_* markers.
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
continue
counts = getattr(layer, "_afb_counts")
if counts is None:
continue
# NOTE(review): `.to(counts.device)` is a no-op — likely a diff remnant.
counts = counts.to(counts.device)
# Sum counts across the configured process group (EP or world).
counts = self.shim.all_reduce_counts(counts)
tokens_seen = int(counts.sum().item())
layer_idx = getattr(layer, "_afb_layer_idx", None)
if layer_idx is None:
counts.zero_()
continue
# Align reduced counts with the bias buffer's device before updating.
bias = getattr(layer, "_afb_bias")
counts_for_update = counts.to(bias.device)
tokens_seen = int(counts_for_update.sum().item())
# local layer-state EMA and bias update
# NOTE(review): diff artifact — the inline EMA/bias block below is the
# pre-change code; the post-change path delegates to shim.update_bias().
if tokens_seen > 0:
freq = counts.float() / float(tokens_seen)
ema = getattr(layer, "_afb_ema")
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
nE = counts.numel()
target = 1.0 / float(nE)
delta = cfg.rate * (target - ema)
delta = delta - delta.mean()
bias = getattr(layer, "_afb_bias")
bias.add_(delta)
if cfg.bias_cap is not None and cfg.bias_cap > 0:
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
# reset step counts
counts.zero_()
return control
@@ -125,8 +119,16 @@ class AuxFreeMoEPlugin(BasePlugin):
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
n_experts = 2
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
# NOTE(review): diff artifact — the two lines below are the pre-change group
# resolution; the block after them is the replacement, which defers EP-group
# creation until torch.distributed is actually initialized.
ep_group = self._resolve_ep_group(cfg) if sync_group == "ep" else None
self._shim = AuxFreeShim(state=state, ep_group=ep_group)
ep_size = getattr(cfg, "expert_parallel_size", None)
ep_group = None
if sync_group == "ep":
if dist.is_available() and dist.is_initialized():
ep_group = self._resolve_ep_group(cfg)
else:
# When dist isn't up yet, the shim resolves the group lazily via
# _maybe_init_ep_group using the ep_size passed below.
LOG.info(
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
)
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
# Discover and prepare layers (attach per-layer buffers)
self._handles = discover_and_prepare_layers(model, adapters, self._shim)