aux_free_router: sync shim state
- drive warmup-aware bias updates and register live buffers
@@ -50,7 +50,7 @@ class BaseMoEAdapter:
         except Exception:  # pragma: no cover - non-critical
             pass

-    def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle) -> None:
+    def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
         device = next(moe_layer.parameters(), torch.tensor(0)).device
         if not hasattr(moe_layer, "_afb_bias"):
             moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
@@ -60,10 +60,11 @@ class BaseMoEAdapter:
             moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
         moe_layer._afb_layer_idx = handle.layer_idx  # type: ignore[attr-defined]
         moe_layer._afb_top_k = handle.top_k  # type: ignore[attr-defined]
+        shim.register_layer_buffers(handle.layer_idx, moe_layer)

     def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
         """Attach per-layer buffers and mark as aux-free enabled."""
-        self._register_aux_buffers(moe_layer, handle)
+        self._register_aux_buffers(moe_layer, handle, shim)
         self._patch_forward_with_aux_free(moe_layer)

     def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
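Context for the buffer registration above: because `_afb_bias` and `_afb_ema` are registered as module buffers, they follow the MoE block across devices and appear in its state_dict by default, which is what lets the shim hold live references to them. A minimal sketch with a toy stand-in module (the class name and sizes are made up for illustration):

import torch
import torch.nn as nn

class ToyMoeBlock(nn.Module):
    def __init__(self, num_experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(8, num_experts)
        # Mirrors what _register_aux_buffers attaches: per-expert bias and EMA load.
        self.register_buffer("_afb_bias", torch.zeros(num_experts))
        self.register_buffer("_afb_ema", torch.zeros(num_experts))

block = ToyMoeBlock()
block.to("cpu")                             # buffers move with the module
print("_afb_bias" in block.state_dict())    # True: buffers are checkpointed by default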
@@ -122,11 +123,60 @@ class MixtralAdapter(BaseMoEAdapter):
     def matches(self, model: nn.Module) -> bool:
         return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"

+    def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
+        self._register_aux_buffers(moe_layer, handle, shim)
+        self._patch_mixtral_forward(moe_layer, shim)
+
     def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
         for m in model.modules():
             if m.__class__.__name__.endswith("SparseMoeBlock"):
                 yield m

+    def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
+        if getattr(moe_layer, "_afb_patched", False):
+            return
+
+        shim_ref = shim
+
+        def afb_forward(self, hidden_states: torch.Tensor):  # type: ignore[no-redef]
+            batch_size, sequence_length, hidden_dim = hidden_states.shape
+            if self.training and getattr(self, "jitter_noise", 0) > 0:
+                hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
+                    1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+                )
+            flat_states = hidden_states.view(-1, hidden_dim)
+            router_logits = self.gate(flat_states)
+
+            layer_idx = int(getattr(self, "_afb_layer_idx", 0))
+            top_k = int(getattr(self, "_afb_top_k", self.top_k))
+            selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
+            routing_weights = routing_weights.to(flat_states.dtype)
+
+            flat_idx = selected_experts.reshape(-1)
+            counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
+            self._afb_counts.add_(counts.to(self._afb_counts.dtype))
+
+            final_hidden_states = torch.zeros(
+                (batch_size * sequence_length, hidden_dim),
+                dtype=flat_states.dtype,
+                device=flat_states.device,
+            )
+
+            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit:
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+                current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
+
+            final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            return final_hidden_states, router_logits
+
+        moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__)  # type: ignore[attr-defined]
+        setattr(moe_layer, "_afb_patched", True)
+

 class Qwen3Adapter(MixtralAdapter):
     family = "qwen3_moe"
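The patched forward adds only two things on top of the stock Mixtral dispatch loop: it asks the shim for the (biased) expert selection, and it accumulates per-expert token counts into `_afb_counts` for the step-end update. A toy illustration of the counting step, with made-up routing results:

import torch

num_experts, top_k = 4, 2
# e.g. 3 tokens, top-2 experts each (values are illustrative)
selected_experts = torch.tensor([[0, 2], [2, 3], [0, 2]])
counts = torch.bincount(selected_experts.reshape(-1), minlength=num_experts)
print(counts)  # tensor([2, 0, 3, 1]) -> expert 2 received 3 of the 6 routed slots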
@@ -154,7 +204,7 @@ class BailingAdapter(BaseMoEAdapter):
         return int(getattr(cfg, "num_experts"))

     def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
-        self._register_aux_buffers(moe_layer, handle)
+        self._register_aux_buffers(moe_layer, handle, shim)
         self._patch_bailing_gate(moe_layer)

     def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
@@ -200,7 +250,7 @@ class Llama4Adapter(BaseMoEAdapter):
                 yield m

     def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
-        self._register_aux_buffers(moe_layer, handle)
+        self._register_aux_buffers(moe_layer, handle, shim)
         self._patch_llama4_router(moe_layer)

     def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
@@ -5,6 +5,9 @@ from typing import Optional

 import torch
 import torch.distributed as dist
+from axolotl.utils.logging import get_logger

+LOG = get_logger(__name__)
+

 @dataclass
@@ -29,22 +32,52 @@ class AuxFreeState:
 class AuxFreeShim:
     """Model-agnostic shim for aux-loss-free expert selection and bias updates."""

-    def __init__(self, state: AuxFreeState, ep_group: Optional[dist.ProcessGroup] = None):
+    def __init__(
+        self,
+        state: AuxFreeState,
+        ep_group: Optional[dist.ProcessGroup] = None,
+        ep_size: Optional[int] = None,
+    ):
         self.state = state
         self.ep_group = ep_group
+        self._ep_size = ep_size
+        self._ep_group_pending = (
+            self.state.cfg.sync_group == "ep" and self.ep_group is None
+        )
+        self._layer_modules: dict[int, torch.nn.Module] = {}

     @torch.no_grad()
     def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
         """Returns (topk_indices, weights) using biased selection and unbiased weights."""
-        b = self.state.bias[layer_idx]
+        module = self._layer_modules.get(layer_idx)
+        if module is not None and hasattr(module, "_afb_bias"):
+            b = getattr(module, "_afb_bias")
+        else:
+            b = self.state.bias[layer_idx]
         biased = logits + b  # bias is a buffer
         topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
         chosen_logits = torch.gather(logits, -1, topk_idx)
         weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
         return topk_idx, weights

+    def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
+        """Bind model buffers so shim updates stay in sync with patched layers."""
+        self._layer_modules[layer_idx] = module
+        bias = getattr(module, "_afb_bias")
+        ema = getattr(module, "_afb_ema")
+        # Keep state views pointing to the same tensors to avoid drift.
+        if layer_idx < len(self.state.bias):
+            self.state.bias[layer_idx] = bias
+        if layer_idx < len(self.state.ema_load):
+            self.state.ema_load[layer_idx] = ema
+
+    def begin_step(self) -> None:
+        """Call once per optimizer step before per-layer updates."""
+        self.state.steps += 1
+
     @torch.no_grad()
     def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
+        self._maybe_init_ep_group()
         if not dist.is_available() or not dist.is_initialized():
             return counts
         group = self.ep_group if self.ep_group is not None else dist.group.WORLD
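For readers skimming the routing change: the bias only shifts which experts win the top-k; the mixture weights are re-derived from the original, unbiased logits of the chosen experts. A toy check of that behaviour (shapes and values are illustrative, not from the codebase):

import torch

logits = torch.tensor([[2.0, 1.9, 0.1, 0.0]])   # router logits for one token
bias = torch.tensor([0.0, 0.0, 2.5, 0.0])       # strongly favour expert 2 for selection

_, topk_idx = torch.topk(logits + bias, k=2, dim=-1)   # selection uses biased scores
chosen = torch.gather(logits, -1, topk_idx)            # weights use the raw logits
weights = torch.softmax(chosen.float(), dim=-1)

print(topk_idx)   # tensor([[2, 0]]): expert 2 is pulled in by the bias
print(weights)    # softmax over the unbiased logits [0.1, 2.0], not the biased ones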
@@ -55,22 +88,63 @@ class AuxFreeShim:
     def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
         """Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
         cfg = self.state.cfg
-        self.state.steps += 1
         if self.state.steps <= cfg.warmup_steps:
             return

         nE = step_counts.numel()
         if tokens_seen <= 0:
             return
-        freq = step_counts.float() / float(tokens_seen)
-        ema = self.state.ema_load[layer_idx]
+        module = self._layer_modules.get(layer_idx)
+        if module is not None and hasattr(module, "_afb_ema"):
+            ema = getattr(module, "_afb_ema")
+            bias = getattr(module, "_afb_bias")
+        else:
+            ema = self.state.ema_load[layer_idx]
+            bias = self.state.bias[layer_idx]
+        counts = step_counts.to(ema.device)
+        freq = counts.float() / float(tokens_seen)
         ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
         target = 1.0 / float(nE)
         delta = cfg.rate * (target - ema)
         # optional mean-centering to keep sum(bias) ~ 0
         delta = delta - delta.mean()
-        bias = self.state.bias[layer_idx]
         bias.add_(delta)
         if cfg.bias_cap is not None and cfg.bias_cap > 0:
             bias.clamp_(-cfg.bias_cap, cfg.bias_cap)

+    def _maybe_init_ep_group(self) -> None:
+        if not self._ep_group_pending:
+            return
+        if not dist.is_available() or not dist.is_initialized():
+            return
+        ep_size = self._ep_size
+        if not ep_size or ep_size <= 1:
+            LOG.warning(
+                "AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
+            )
+            self.ep_group = dist.group.WORLD
+            self._ep_group_pending = False
+            return
+        world = dist.get_world_size()
+        if world % ep_size != 0:
+            LOG.warning(
+                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
+                ep_size,
+                world,
+            )
+            self.ep_group = dist.group.WORLD
+            self._ep_group_pending = False
+            return
+        if ep_size == world:
+            self.ep_group = dist.group.WORLD
+        else:
+            rank = dist.get_rank()
+            group_start = (rank // ep_size) * ep_size
+            ranks = tuple(range(group_start, group_start + ep_size))
+            self.ep_group = dist.new_group(ranks)
+        LOG.info(
+            "AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
+            ep_size,
+            dist.get_world_size(),
+        )
+        self._ep_group_pending = False
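A standalone sketch of the rule `update_bias` applies, with illustrative constants (the real values come from the plugin config): the EMA of per-expert load is pulled toward the observed step frequencies, the bias moves toward the uniform target, and the mean-centred, clamped delta keeps the bias vector bounded and roughly zero-sum.

import torch

momentum, rate, bias_cap = 0.9, 0.01, 1.0     # illustrative config values
counts = torch.tensor([12.0, 2.0, 1.0, 1.0])  # tokens routed to each expert this step
tokens_seen = counts.sum()

ema = torch.zeros(4)
bias = torch.zeros(4)

freq = counts / tokens_seen                       # [0.75, 0.125, 0.0625, 0.0625]
ema.mul_(momentum).add_((1.0 - momentum) * freq)  # EMA of observed load
target = 1.0 / counts.numel()                     # uniform target, 0.25
delta = rate * (target - ema)
delta = delta - delta.mean()                      # keep sum(bias) ~ 0
bias.add_(delta).clamp_(-bias_cap, bias_cap)

print(bias)  # overloaded expert 0 gets a negative bias, starved experts a positive one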
@@ -42,29 +42,23 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):

     def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
         # Iterate prepared MoE layers and apply the bias update rule.
-        cfg = self.shim.state.cfg
+        self.shim.begin_step()
         for layer in self.layer_modules:
             if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
                 continue
             counts = getattr(layer, "_afb_counts")
             if counts is None:
                 continue
-            counts = counts.to(counts.device)
             counts = self.shim.all_reduce_counts(counts)
-            tokens_seen = int(counts.sum().item())
             layer_idx = getattr(layer, "_afb_layer_idx", None)
             if layer_idx is None:
                 counts.zero_()
                 continue
+            bias = getattr(layer, "_afb_bias")
+            counts_for_update = counts.to(bias.device)
+            tokens_seen = int(counts_for_update.sum().item())
             # local layer-state EMA and bias update
-            if tokens_seen > 0:
-                freq = counts.float() / float(tokens_seen)
-                ema = getattr(layer, "_afb_ema")
-                ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
-                nE = counts.numel()
-                target = 1.0 / float(nE)
-                delta = cfg.rate * (target - ema)
-                delta = delta - delta.mean()
-                bias = getattr(layer, "_afb_bias")
-                bias.add_(delta)
-                if cfg.bias_cap is not None and cfg.bias_cap > 0:
-                    bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
+            self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
             # reset step counts
             counts.zero_()
         return control
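With the inline math gone, the callback's per-layer job is only to produce a count vector and a token total for `update_bias`. A toy illustration of why the counts are reduced across ranks first (assuming the all-reduce sums them, which is what the token total implies):

import torch

# Per-rank counts for one layer after a step (illustrative numbers).
rank0 = torch.tensor([6, 2, 0, 0])
rank1 = torch.tensor([0, 0, 5, 3])

global_counts = rank0 + rank1                  # what all_reduce_counts would produce
tokens_seen = int(global_counts.sum().item())  # 16, so frequencies are global
freq = global_counts.float() / tokens_seen
print(freq)  # tensor([0.3750, 0.1250, 0.3125, 0.1875])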
@@ -125,8 +119,16 @@ class AuxFreeMoEPlugin(BasePlugin):
             # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
             n_experts = 2
         state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
-        ep_group = self._resolve_ep_group(cfg) if sync_group == "ep" else None
-        self._shim = AuxFreeShim(state=state, ep_group=ep_group)
+        ep_size = getattr(cfg, "expert_parallel_size", None)
+        ep_group = None
+        if sync_group == "ep":
+            if dist.is_available() and dist.is_initialized():
+                ep_group = self._resolve_ep_group(cfg)
+            else:
+                LOG.info(
+                    "AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
+                )
+        self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)

         # Discover and prepare layers (attach per-layer buffers)
         self._handles = discover_and_prepare_layers(model, adapters, self._shim)
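Because the plugin may build the shim before `torch.distributed` is initialized, it now just forwards `expert_parallel_size` and lets `_maybe_init_ep_group` build the reduction group lazily. The grouping rule is plain contiguous partitioning; a small sketch of that arithmetic (the helper name is made up for illustration):

def contiguous_ep_groups(world_size: int, ep_size: int) -> list[tuple[int, ...]]:
    # Same rule as _maybe_init_ep_group: group_start = (rank // ep_size) * ep_size,
    # so ranks are partitioned into contiguous blocks of ep_size.
    return [tuple(range(start, start + ep_size)) for start in range(0, world_size, ep_size)]

assert contiguous_ep_groups(8, 4) == [(0, 1, 2, 3), (4, 5, 6, 7)]
assert contiguous_ep_groups(8, 8) == [tuple(range(8))]  # ep_size == world maps to the world group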