From 6eac9ac372ddedb09a50203bcc23871fc5090f88 Mon Sep 17 00:00:00 2001
From: lhl
Date: Tue, 28 Oct 2025 08:27:48 +0000
Subject: [PATCH] aux_free_router: emit telemetry metrics

- log per-layer expert load stats, bias magnitude, and bias-sign oscillation rate
- honor a configurable telemetry interval with Trainer logging integration
---
 .../integrations/aux_free_router/README.md |  3 +-
 .../integrations/aux_free_router/core.py   |  6 ++
 .../integrations/aux_free_router/plugin.py | 86 ++++++++++++++++++-
 tests/unit/test_aux_free_adapters.py       | 71 +++++++++++----
 4 files changed, 147 insertions(+), 19 deletions(-)

diff --git a/src/axolotl/integrations/aux_free_router/README.md b/src/axolotl/integrations/aux_free_router/README.md
index f254e253e..1b77e49eb 100644
--- a/src/axolotl/integrations/aux_free_router/README.md
+++ b/src/axolotl/integrations/aux_free_router/README.md
@@ -30,6 +30,7 @@ Config keys
 - moe_afb_warmup_steps: delay before applying updates. Default: 0.
 - moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
 - expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
+- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
 
 Compatibility
 - Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
@@ -37,4 +38,4 @@ Compatibility
 
 Notes
 - If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
-- Telemetry: future updates will log per-expert loads and bias magnitudes.
+- Telemetry: at the configured interval, logs per-layer min/mean/max expert load fractions, the max |bias| magnitude, and the fraction of bias entries whose sign flipped since the last update.
diff --git a/src/axolotl/integrations/aux_free_router/core.py b/src/axolotl/integrations/aux_free_router/core.py
index 66a94689d..9012c4844 100644
--- a/src/axolotl/integrations/aux_free_router/core.py
+++ b/src/axolotl/integrations/aux_free_router/core.py
@@ -17,6 +17,7 @@ class AuxFreeConfig:
     bias_cap: float = 2.0
     warmup_steps: int = 0
     sync_group: str = "world"  # or "ep"
+    telemetry_interval: Optional[int] = None
 
 
 class AuxFreeState:
@@ -45,6 +46,7 @@ class AuxFreeShim:
             self.state.cfg.sync_group == "ep" and self.ep_group is None
         )
         self._layer_modules: dict[int, torch.nn.Module] = {}
+        self._prev_bias_sign: dict[int, torch.Tensor] = {}
 
     @torch.no_grad()
     def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -75,6 +77,9 @@ class AuxFreeShim:
         """Call once per optimizer step before per-layer updates."""
         self.state.steps += 1
 
+    def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
+        return self._prev_bias_sign.get(layer_idx)
+
     @torch.no_grad()
     def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
         self._maybe_init_ep_group()
@@ -111,6 +116,7 @@ class AuxFreeShim:
         bias.add_(delta)
         if cfg.bias_cap is not None and cfg.bias_cap > 0:
             bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
+        self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())
 
     def _maybe_init_ep_group(self) -> None:
         if not self._ep_group_pending:
diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py
index fc2280032..3dc456222 100644
--- a/src/axolotl/integrations/aux_free_router/plugin.py
+++ b/src/axolotl/integrations/aux_free_router/plugin.py
@@ -36,9 +36,17 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
     routing patches in a follow-up).
""" - def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]): + def __init__( + self, + shim: AuxFreeShim, + layer_modules: list[torch.nn.Module], + telemetry_interval: Optional[int] = None, + ): self.shim = shim self.layer_modules = layer_modules + self.telemetry_interval = telemetry_interval + self._prev_bias_sign: dict[int, torch.Tensor] = {} + self._telemetry_buffer: dict[int, dict[str, float]] = {} def on_step_end(self, args, state, control, **kwargs): # noqa: D401 # Iterate prepared MoE layers and apply the bias update rule. @@ -59,10 +67,72 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback): tokens_seen = int(counts_for_update.sum().item()) # local layer-state EMA and bias update self.shim.update_bias(layer_idx, counts_for_update, tokens_seen) + self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias) # reset step counts counts.zero_() + + if self._should_log(args, state): + if self._telemetry_buffer: + if not hasattr(state, "log_history"): + state.log_history = [] + log_entry = {"step": state.global_step} + for layer_idx, metrics in sorted(self._telemetry_buffer.items()): + prefix = f"moe_afb/l{layer_idx}_" + for key, value in metrics.items(): + log_entry[f"{prefix}{key}"] = value + state.log_history.append(log_entry) + control.should_log = True + self._telemetry_buffer.clear() return control + def _collect_telemetry( + self, + layer_idx: int, + counts: torch.Tensor, + tokens_seen: int, + bias: torch.Tensor, + ) -> None: + if tokens_seen <= 0: + return + freq = counts.float() / float(tokens_seen) + load_min = freq.min().item() + load_mean = freq.mean().item() + load_max = freq.max().item() + bias_abs_max = bias.abs().max().item() + + prev_sign = self._prev_bias_sign.get(layer_idx) + current_sign = torch.sign(bias.detach()) + if prev_sign is None or prev_sign.shape != current_sign.shape: + oscillation = 0.0 + else: + changed = (current_sign != prev_sign) & ( + (current_sign != 0) | (prev_sign != 0) + ) + if changed.numel() == 0: + oscillation = 0.0 + else: + oscillation = changed.float().mean().item() + self._prev_bias_sign[layer_idx] = current_sign.clone() + + self._telemetry_buffer[layer_idx] = { + "load_min": load_min, + "load_mean": load_mean, + "load_max": load_max, + "bias_abs_max": bias_abs_max, + "bias_sign_flip_frac": oscillation, + } + + def _should_log(self, args, state) -> bool: + interval = ( + self.telemetry_interval + if self.telemetry_interval is not None and self.telemetry_interval > 0 + else getattr(args, "logging_steps", 0) + ) + interval = max(1, int(interval)) if interval else 0 + if interval == 0: + return False + return state.global_step % interval == 0 + class AuxFreeMoEPlugin(BasePlugin): """Plugin that enables aux-loss-free routing when configured.""" @@ -95,8 +165,14 @@ class AuxFreeMoEPlugin(BasePlugin): bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0 warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0 sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world" + telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None) af_cfg = AuxFreeConfig( - rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group + rate=rate, + momentum=momentum, + bias_cap=bias_cap, + warmup_steps=warmup, + sync_group=sync_group, + telemetry_interval=telemetry_interval, ) # Discover layers to count the number and experts for state sizing @@ -170,6 +246,10 @@ class AuxFreeMoEPlugin(BasePlugin): return [] # gather concrete layer 
         layers = [h.layer for h in self._handles]
-        cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
+        cb = MoeAuxFreeBiasUpdateCallback(
+            self._shim,
+            layers,
+            telemetry_interval=self._shim.state.cfg.telemetry_interval,
+        )
         LOG.info("AuxFreeMoE: registering post-step bias update callback")
         return [cb]
diff --git a/tests/unit/test_aux_free_adapters.py b/tests/unit/test_aux_free_adapters.py
index 214126892..b7889fcfa 100644
--- a/tests/unit/test_aux_free_adapters.py
+++ b/tests/unit/test_aux_free_adapters.py
@@ -148,12 +148,23 @@ def _build_mixtral_model():
     return DummyModel(layer), layer
 
 
-def _run_callback(plugin, cfg):
+def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
     callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
     assert callbacks, "expected aux-free callback to be registered"
     callback = callbacks[0]
-    dummy = SimpleNamespace()
-    callback.on_step_end(args=dummy, state=dummy, control=dummy)
+    if args is None:
+        args = SimpleNamespace(logging_steps=1)
+    if state is None:
+        state = SimpleNamespace(global_step=1, log_history=[])
+    if control is None:
+        control = SimpleNamespace(
+            should_log=False,
+            should_evaluate=False,
+            should_save=False,
+            should_training_stop=False,
+        )
+    callback.on_step_end(args=args, state=state, control=control)
+    return state, control
 
 
 class TestAuxFreeAdapters(unittest.TestCase):
@@ -193,15 +204,10 @@ class TestAuxFreeAdapters(unittest.TestCase):
         plugin = AuxFreeMoEPlugin()
         plugin.post_model_build(cfg, model)
 
-        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
-        self.assertTrue(callbacks)
-        callback = callbacks[0]
-        dummy = SimpleNamespace()
-
         def _step():
             hidden = torch.randn(2, 3, block.config.hidden_size)
             block(hidden)
-            callback.on_step_end(args=dummy, state=dummy, control=dummy)
+            _run_callback(plugin, cfg)
 
         # Warmup steps should leave bias untouched.
         _step()
@@ -244,11 +250,6 @@
         self.assertIsNotNone(plugin._shim)
         self.assertIsNone(plugin._shim.ep_group)
 
-        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
-        self.assertTrue(callbacks)
-        callback = callbacks[0]
-        dummy = SimpleNamespace()
-
         tmp_init = tempfile.NamedTemporaryFile(delete=False)
         tmp_init.close()
         init_method = f"file://{tmp_init.name}"
@@ -256,12 +257,52 @@
         try:
             hidden = torch.randn(2, 3, block.config.hidden_size)
             block(hidden)
-            callback.on_step_end(args=dummy, state=dummy, control=dummy)
+            _run_callback(
+                plugin,
+                cfg,
+                args=SimpleNamespace(logging_steps=1),
+                state=SimpleNamespace(global_step=1, log_history=[]),
+                control=SimpleNamespace(
+                    should_log=False,
+                    should_evaluate=False,
+                    should_save=False,
+                    should_training_stop=False,
+                ),
+            )
             self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
         finally:
             dist.destroy_process_group()
             os.unlink(tmp_init.name)
 
+    def test_telemetry_logging(self):
+        model, layer = _build_mixtral_model()
+        layer.jitter_noise = 0.0
+        cfg = _cfg(moe_afb_telemetry_interval=1)
+        plugin = AuxFreeMoEPlugin()
+        plugin.post_model_build(cfg, model)
+
+        hidden_dim = layer.config.hidden_size
+        hidden = torch.randn(2, 3, hidden_dim)
+        layer(hidden)
+
+        args = SimpleNamespace(logging_steps=1)
+        state = SimpleNamespace(global_step=1, log_history=[])
+        control = SimpleNamespace(
+            should_log=False,
+            should_evaluate=False,
+            should_save=False,
+            should_training_stop=False,
+        )
+        _run_callback(plugin, cfg, args=args, state=state, control=control)
+
+        self.assertTrue(control.should_log)
+        self.assertTrue(state.log_history)
+        telemetry = state.log_history[-1]
+        self.assertEqual(telemetry["step"], state.global_step)
+        self.assertIn("moe_afb/l0_load_min", telemetry)
+        self.assertIn("moe_afb/l0_load_max", telemetry)
+        self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
+
 
 if __name__ == "__main__":
     unittest.main()
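
For reference, below is a minimal standalone sketch of the per-layer metric
arithmetic that _collect_telemetry performs. The tensor values are invented for
illustration; only the formulas mirror the patch.

import torch

counts = torch.tensor([5.0, 9.0, 2.0, 8.0])      # tokens routed to each expert this step
bias = torch.tensor([0.10, -0.05, 0.20, -0.15])  # router bias after the step's update
prev_sign = torch.tensor([1.0, 1.0, 1.0, -1.0])  # sign(bias) recorded on the previous step

tokens_seen = counts.sum().item()
freq = counts / tokens_seen                      # per-expert load fraction

current_sign = torch.sign(bias)
# A flip counts only when at least one of the two signs is nonzero, so experts
# whose bias stays exactly zero never register as oscillating.
changed = (current_sign != prev_sign) & ((current_sign != 0) | (prev_sign != 0))

print("load_min", round(freq.min().item(), 3))               # 0.083
print("load_mean", round(freq.mean().item(), 3))             # 0.25
print("load_max", round(freq.max().item(), 3))               # 0.375
print("bias_abs_max", round(bias.abs().max().item(), 3))     # 0.2
print("bias_sign_flip_frac", changed.float().mean().item())  # 0.25 (expert 1 flipped)

At each telemetry interval (moe_afb_telemetry_interval, falling back to
logging_steps), these five values are flattened into state.log_history under
keys such as moe_afb/l0_load_min.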