diff --git a/src/axolotl/integrations/aux_free_router/adapters.py b/src/axolotl/integrations/aux_free_router/adapters.py
index cbcf84618..e349c268b 100644
--- a/src/axolotl/integrations/aux_free_router/adapters.py
+++ b/src/axolotl/integrations/aux_free_router/adapters.py
@@ -50,7 +50,7 @@ class BaseMoEAdapter:
         except Exception:  # pragma: no cover - non-critical
             pass
 
-    def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle) -> None:
+    def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
         device = next(moe_layer.parameters(), torch.tensor(0)).device
         if not hasattr(moe_layer, "_afb_bias"):
             moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
@@ -60,10 +60,11 @@ class BaseMoEAdapter:
             moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
         moe_layer._afb_layer_idx = handle.layer_idx  # type: ignore[attr-defined]
         moe_layer._afb_top_k = handle.top_k  # type: ignore[attr-defined]
+        shim.register_layer_buffers(handle.layer_idx, moe_layer)
 
     def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
         """Attach per-layer buffers and mark as aux-free enabled."""
-        self._register_aux_buffers(moe_layer, handle)
+        self._register_aux_buffers(moe_layer, handle, shim)
         self._patch_forward_with_aux_free(moe_layer)
 
     def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
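With this change every prepared MoE module owns `_afb_bias`, `_afb_counts`, and `_afb_ema` buffers, and the new `register_layer_buffers` call makes the shim alias those exact tensors instead of keeping private copies. A self-contained sketch of that aliasing (toy 8-expert buffer; `state_bias` stands in for `AuxFreeState.bias` and is not part of the patch):

import torch
import torch.nn as nn

layer = nn.Module()
layer.register_buffer("_afb_bias", torch.zeros(8))
state_bias = [torch.zeros(8)]     # stand-in for AuxFreeState.bias
state_bias[0] = layer._afb_bias   # what register_layer_buffers does: alias, not copy
layer._afb_bias.add_(0.5)
assert state_bias[0][0].item() == 0.5  # shim-side view tracks the module buffer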
@@ -122,11 +123,60 @@ class MixtralAdapter(BaseMoEAdapter):
     def matches(self, model: nn.Module) -> bool:
         return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
 
+    def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
+        self._register_aux_buffers(moe_layer, handle, shim)
+        self._patch_mixtral_forward(moe_layer, shim)
+
     def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
         for m in model.modules():
             if m.__class__.__name__.endswith("SparseMoeBlock"):
                 yield m
 
+    def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
+        if getattr(moe_layer, "_afb_patched", False):
+            return
+
+        shim_ref = shim
+
+        def afb_forward(self, hidden_states: torch.Tensor):  # type: ignore[no-redef]
+            batch_size, sequence_length, hidden_dim = hidden_states.shape
+            if self.training and getattr(self, "jitter_noise", 0) > 0:
+                hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
+                    1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+                )
+            flat_states = hidden_states.view(-1, hidden_dim)
+            router_logits = self.gate(flat_states)
+
+            layer_idx = int(getattr(self, "_afb_layer_idx", 0))
+            top_k = int(getattr(self, "_afb_top_k", self.top_k))
+            selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
+            routing_weights = routing_weights.to(flat_states.dtype)
+
+            flat_idx = selected_experts.reshape(-1)
+            counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
+            self._afb_counts.add_(counts.to(self._afb_counts.dtype))
+
+            final_hidden_states = torch.zeros(
+                (batch_size * sequence_length, hidden_dim),
+                dtype=flat_states.dtype,
+                device=flat_states.device,
+            )
+
+            expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit:
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+                current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
+
+            final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            return final_hidden_states, router_logits
+
+        moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__)  # type: ignore[attr-defined]
+        setattr(moe_layer, "_afb_patched", True)
+
 
 class Qwen3Adapter(MixtralAdapter):
     family = "qwen3_moe"
 
@@ -154,7 +204,7 @@ class BailingAdapter(BaseMoEAdapter):
         return int(getattr(cfg, "num_experts"))
 
     def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
-        self._register_aux_buffers(moe_layer, handle)
+        self._register_aux_buffers(moe_layer, handle, shim)
         self._patch_bailing_gate(moe_layer)
 
     def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
@@ -200,7 +250,7 @@ class Llama4Adapter(BaseMoEAdapter):
             yield m
 
     def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
-        self._register_aux_buffers(moe_layer, handle)
+        self._register_aux_buffers(moe_layer, handle, shim)
         self._patch_llama4_router(moe_layer)
 
     def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
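The patched forward delegates routing to `AuxFreeShim.select_experts`, whose contract is: the bias only steers which experts get picked, while the mixing weights come from a softmax over the unbiased logits. A self-contained sketch of that rule (toy shapes, 4 tokens x 8 experts; `select_experts_sketch` is illustrative, not part of the patch):

import torch

def select_experts_sketch(logits: torch.Tensor, bias: torch.Tensor, top_k: int):
    biased = logits + bias                       # bias affects selection only
    _, topk_idx = torch.topk(biased, k=top_k, dim=-1)
    chosen = torch.gather(logits, -1, topk_idx)  # weights use UNBIASED logits
    weights = torch.softmax(chosen.float(), dim=-1).to(logits.dtype)
    return topk_idx, weights

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
bias = torch.zeros(8)
bias[3] = -1.0              # discourage an overloaded expert
idx, w = select_experts_sketch(logits, bias, top_k=2)
assert torch.allclose(w.sum(dim=-1), torch.ones(4))

Because the bias never enters the softmax, the output for a fixed routing decision is identical to the unbiased router's, which is what makes the balancing loss-free.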
diff --git a/src/axolotl/integrations/aux_free_router/core.py b/src/axolotl/integrations/aux_free_router/core.py
index a673a9856..66a94689d 100644
--- a/src/axolotl/integrations/aux_free_router/core.py
+++ b/src/axolotl/integrations/aux_free_router/core.py
@@ -5,6 +5,9 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
 
 
 @dataclass
@@ -29,22 +32,52 @@ class AuxFreeState:
 class AuxFreeShim:
     """Model-agnostic shim for aux-loss-free expert selection and bias updates."""
 
-    def __init__(self, state: AuxFreeState, ep_group: Optional[dist.ProcessGroup] = None):
+    def __init__(
+        self,
+        state: AuxFreeState,
+        ep_group: Optional[dist.ProcessGroup] = None,
+        ep_size: Optional[int] = None,
+    ):
         self.state = state
         self.ep_group = ep_group
+        self._ep_size = ep_size
+        self._ep_group_pending = (
+            self.state.cfg.sync_group == "ep" and self.ep_group is None
+        )
+        self._layer_modules: dict[int, torch.nn.Module] = {}
 
     @torch.no_grad()
     def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
         """Returns (topk_indices, weights) using biased selection and unbiased weights."""
-        b = self.state.bias[layer_idx]
+        module = self._layer_modules.get(layer_idx)
+        if module is not None and hasattr(module, "_afb_bias"):
+            b = getattr(module, "_afb_bias")
+        else:
+            b = self.state.bias[layer_idx]
         biased = logits + b  # bias is a buffer
         topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
         chosen_logits = torch.gather(logits, -1, topk_idx)
         weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
         return topk_idx, weights
 
+    def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
+        """Bind model buffers so shim updates stay in sync with patched layers."""
+        self._layer_modules[layer_idx] = module
+        bias = getattr(module, "_afb_bias")
+        ema = getattr(module, "_afb_ema")
+        # Keep state views pointing to the same tensors to avoid drift.
+        if layer_idx < len(self.state.bias):
+            self.state.bias[layer_idx] = bias
+        if layer_idx < len(self.state.ema_load):
+            self.state.ema_load[layer_idx] = ema
+
+    def begin_step(self) -> None:
+        """Call once per optimizer step before per-layer updates."""
+        self.state.steps += 1
+
     @torch.no_grad()
     def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
+        self._maybe_init_ep_group()
         if not dist.is_available() or not dist.is_initialized():
             return counts
         group = self.ep_group if self.ep_group is not None else dist.group.WORLD
@@ -55,22 +88,63 @@ class AuxFreeShim:
     def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
         """Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
         cfg = self.state.cfg
-        self.state.steps += 1
         if self.state.steps <= cfg.warmup_steps:
             return
         nE = step_counts.numel()
         if tokens_seen <= 0:
             return
-        freq = step_counts.float() / float(tokens_seen)
-        ema = self.state.ema_load[layer_idx]
+        module = self._layer_modules.get(layer_idx)
+        if module is not None and hasattr(module, "_afb_ema"):
+            ema = getattr(module, "_afb_ema")
+            bias = getattr(module, "_afb_bias")
+        else:
+            ema = self.state.ema_load[layer_idx]
+            bias = self.state.bias[layer_idx]
+        counts = step_counts.to(ema.device)
+        freq = counts.float() / float(tokens_seen)
         ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
         target = 1.0 / float(nE)
         delta = cfg.rate * (target - ema)
         # optional mean-centering to keep sum(bias) ~ 0
         delta = delta - delta.mean()
-        bias = self.state.bias[layer_idx]
         bias.add_(delta)
         if cfg.bias_cap is not None and cfg.bias_cap > 0:
             bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
 
+    def _maybe_init_ep_group(self) -> None:
+        if not self._ep_group_pending:
+            return
+        if not dist.is_available() or not dist.is_initialized():
+            return
+        ep_size = self._ep_size
+        if not ep_size or ep_size <= 1:
+            LOG.warning(
+                "AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
+            )
+            self.ep_group = dist.group.WORLD
+            self._ep_group_pending = False
+            return
+        world = dist.get_world_size()
+        if world % ep_size != 0:
+            LOG.warning(
+                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
+                ep_size,
+                world,
+            )
+            self.ep_group = dist.group.WORLD
+            self._ep_group_pending = False
+            return
+        if ep_size == world:
+            self.ep_group = dist.group.WORLD
+        else:
+            # dist.new_group() is collective: every rank must create every
+            # subgroup in the same order, keeping only the one it belongs to.
+            rank = dist.get_rank()
+            for group_start in range(0, world, ep_size):
+                ranks = tuple(range(group_start, group_start + ep_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    self.ep_group = group
+        LOG.info(
+            "AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
+            ep_size,
+            world,
+        )
+        self._ep_group_pending = False
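For intuition on the `update_bias` rule, here is the same arithmetic run standalone with hypothetical hyperparameters (momentum=0.9, rate=0.01, 4 experts; none of these values come from the patch):

import torch

momentum, rate = 0.9, 0.01     # hypothetical cfg.momentum / cfg.rate
counts = torch.tensor([700.0, 100.0, 100.0, 100.0])  # expert 0 is hot
tokens_seen = float(counts.sum())                    # 1000 routed slots
freq = counts / tokens_seen                          # [0.70, 0.10, 0.10, 0.10]
ema = torch.zeros(4)
ema.mul_(momentum).add_((1.0 - momentum) * freq)     # EMA-smoothed load
delta = rate * (1.0 / 4 - ema)       # pull toward the uniform target 1/nE
delta = delta - delta.mean()         # mean-center so sum(bias) stays ~ 0
bias = torch.zeros(4)
bias.add_(delta)  # bias[0] ~= -4.5e-4: expert 0 is selected less next step

The mean-centering keeps the bias vector a pure tie-breaker among experts rather than a global shift on all logits.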
diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py
index 6f1c2a633..fc2280032 100644
--- a/src/axolotl/integrations/aux_free_router/plugin.py
+++ b/src/axolotl/integrations/aux_free_router/plugin.py
@@ -42,29 +42,23 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
 
     def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
         # Iterate prepared MoE layers and apply the bias update rule.
-        cfg = self.shim.state.cfg
+        self.shim.begin_step()
         for layer in self.layer_modules:
             if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
                 continue
             counts = getattr(layer, "_afb_counts")
             if counts is None:
                 continue
-            counts = counts.to(counts.device)
             counts = self.shim.all_reduce_counts(counts)
-            tokens_seen = int(counts.sum().item())
+            layer_idx = getattr(layer, "_afb_layer_idx", None)
+            if layer_idx is None:
+                counts.zero_()
+                continue
+            bias = getattr(layer, "_afb_bias")
+            counts_for_update = counts.to(bias.device)
+            tokens_seen = int(counts_for_update.sum().item())
             # local layer-state EMA and bias update
-            if tokens_seen > 0:
-                freq = counts.float() / float(tokens_seen)
-                ema = getattr(layer, "_afb_ema")
-                ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
-                nE = counts.numel()
-                target = 1.0 / float(nE)
-                delta = cfg.rate * (target - ema)
-                delta = delta - delta.mean()
-                bias = getattr(layer, "_afb_bias")
-                bias.add_(delta)
-                if cfg.bias_cap is not None and cfg.bias_cap > 0:
-                    bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
+            self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
             # reset step counts
             counts.zero_()
         return control
@@ -125,8 +119,16 @@ class AuxFreeMoEPlugin(BasePlugin):
         # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
         n_experts = 2
         state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
-        ep_group = self._resolve_ep_group(cfg) if sync_group == "ep" else None
-        self._shim = AuxFreeShim(state=state, ep_group=ep_group)
+        ep_size = getattr(cfg, "expert_parallel_size", None)
+        ep_group = None
+        if sync_group == "ep":
+            if dist.is_available() and dist.is_initialized():
+                ep_group = self._resolve_ep_group(cfg)
+            else:
+                LOG.info(
+                    "AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
+                )
+        self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
 
         # Discover and prepare layers (attach per-layer buffers)
         self._handles = discover_and_prepare_layers(model, adapters, self._shim)
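One behavioral consequence of moving `state.steps += 1` out of `update_bias` and into the new `begin_step`: warmup is now measured in optimizer steps, not per-layer update calls. A toy illustration of the difference (hypothetical 4-layer model with warmup_steps=100; plain counters, no patch code involved):

# Hypothetical 4-layer MoE, cfg.warmup_steps = 100.
num_layers, warmup_steps = 4, 100

steps_old = 0                      # old: update_bias() bumped steps per layer
for _ in range(25):                # 25 optimizer steps...
    for _layer in range(num_layers):
        steps_old += 1
assert steps_old == warmup_steps   # ...and warmup is already "over"

steps_new = 0                      # new: begin_step() bumps once per step
for _ in range(25):
    steps_new += 1
assert steps_new == 25             # warmup now really lasts 100 optimizer steps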