diff --git a/src/axolotl/integrations/aux_free_router/README.md b/src/axolotl/integrations/aux_free_router/README.md
index a84ec17cd..493c952fa 100644
--- a/src/axolotl/integrations/aux_free_router/README.md
+++ b/src/axolotl/integrations/aux_free_router/README.md
@@ -30,7 +30,6 @@ Config keys
 - moe_afb_warmup_steps: delay before applying updates. Default: 0.
 - moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
 - expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
-- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
 
 Compatibility
 - Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
@@ -38,7 +37,7 @@ Compatibility
 
 Notes
 - If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
-- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction at the configured interval.
+- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction using the Trainer’s `logging_steps` cadence.
 - Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.
 
 Telemetry metrics
@@ -47,5 +46,5 @@ Telemetry metrics
 - `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (simple oscillation indicator).
 
 Usage tips
-- Leave `moe_afb_telemetry_interval` unset to log on the Trainer’s `logging_steps`. Increase the interval for large jobs to reduce log volume.
+- Increase `logging_steps` if router telemetry becomes noisy for large jobs; the plugin follows the Trainer’s logging cadence.
 - Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens min/max spread without the auxiliary loss term.
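With `moe_afb_telemetry_interval` removed, an enabling config reduces to the remaining documented keys plus the Trainer’s `logging_steps`. A minimal sketch, assuming the plugin is registered through axolotl’s standard `plugins:` list; the exact class path is inferred from this diff’s module layout and should be verified against the integration docs:

```yaml
# Sketch: aux-free MoE routing with telemetry on the logging_steps cadence.
# The plugin path below is an assumption based on plugin.py's location.
plugins:
  - axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin

moe_update_rate: 0.01       # per-step bias update rate
moe_update_momentum: 0.9    # smoothing applied to bias updates
moe_bias_cap: 2.0           # clamp on per-expert bias magnitude
moe_afb_warmup_steps: 0     # optimizer steps before updates apply
moe_bias_sync_group: world  # or "ep" together with expert_parallel_size

logging_steps: 10           # router telemetry now follows this cadence
```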
- """ + """Post-step callback to update aux-free biases from accumulated expert counts.""" def __init__( self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module], - telemetry_interval: Optional[int] = None, + trainer: Any, ): self.shim = shim self.layer_modules = layer_modules - self.telemetry_interval = telemetry_interval + self.trainer = trainer self._prev_bias_sign: dict[int, torch.Tensor] = {} self._telemetry_buffer: dict[int, dict[str, float]] = {} @@ -71,17 +66,14 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback): # reset step counts counts.zero_() - if self._should_log(args, state): - if self._telemetry_buffer: - if not hasattr(state, "log_history"): - state.log_history = [] - log_entry = {"step": state.global_step} - for layer_idx, metrics in sorted(self._telemetry_buffer.items()): - prefix = f"moe_afb/l{layer_idx}_" - for key, value in metrics.items(): - log_entry[f"{prefix}{key}"] = value - state.log_history.append(log_entry) - control.should_log = True + if self._should_log(args, state) and self._telemetry_buffer: + logs: dict[str, float] = {} + for layer_idx, metrics in sorted(self._telemetry_buffer.items()): + prefix = f"moe_afb/l{layer_idx}_" + for key, value in metrics.items(): + logs[f"{prefix}{key}"] = value + if logs and hasattr(self.trainer, "log"): + self.trainer.log(logs) self._telemetry_buffer.clear() return control @@ -123,15 +115,14 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback): } def _should_log(self, args, state) -> bool: - interval = ( - self.telemetry_interval - if self.telemetry_interval is not None and self.telemetry_interval > 0 - else getattr(args, "logging_steps", 0) - ) - interval = max(1, int(interval)) if interval else 0 - if interval == 0: + interval = getattr(args, "logging_steps", 0) + if not interval: return False - return state.global_step % interval == 0 + try: + interval = max(1, int(interval)) + except (TypeError, ValueError): + return False + return interval > 0 and state.global_step % interval == 0 class AuxFreeMoEPlugin(BasePlugin): @@ -165,14 +156,12 @@ class AuxFreeMoEPlugin(BasePlugin): bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0 warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0 sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world" - telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None) af_cfg = AuxFreeConfig( rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group, - telemetry_interval=telemetry_interval, ) # Discover layers to count the number and experts for state sizing @@ -249,7 +238,7 @@ class AuxFreeMoEPlugin(BasePlugin): cb = MoeAuxFreeBiasUpdateCallback( self._shim, layers, - telemetry_interval=self._shim.state.cfg.telemetry_interval, + trainer, ) LOG.info("AuxFreeMoE: registering post-step bias update callback") return [cb] diff --git a/tests/e2e/test_llama4_moe_aux_free.py b/tests/e2e/test_llama4_moe_aux_free.py index bc5341dd9..bceb55816 100644 --- a/tests/e2e/test_llama4_moe_aux_free.py +++ b/tests/e2e/test_llama4_moe_aux_free.py @@ -52,7 +52,6 @@ class TestLlama4MoeAuxFree(unittest.TestCase): "moe_update_rate": 0.01, "moe_update_momentum": 0.9, "moe_bias_cap": 2.0, - "moe_afb_telemetry_interval": 1, } ) diff --git a/tests/e2e/test_ring_moe_aux_free.py b/tests/e2e/test_ring_moe_aux_free.py index 1905582d7..992f703e4 100644 --- a/tests/e2e/test_ring_moe_aux_free.py +++ b/tests/e2e/test_ring_moe_aux_free.py @@ -53,7 +53,6 @@ class TestRingMoeAuxFree(unittest.TestCase): 
"moe_update_rate": 0.01, "moe_update_momentum": 0.9, "moe_bias_cap": 2.0, - "moe_afb_telemetry_interval": 1, } ) diff --git a/tests/unit/test_aux_free_adapters.py b/tests/unit/test_aux_free_adapters.py index b7889fcfa..3bc3ac8e5 100644 --- a/tests/unit/test_aux_free_adapters.py +++ b/tests/unit/test_aux_free_adapters.py @@ -149,9 +149,6 @@ def _build_mixtral_model(): def _run_callback(plugin, cfg, *, args=None, state=None, control=None): - callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace()) - assert callbacks, "expected aux-free callback to be registered" - callback = callbacks[0] if args is None: args = SimpleNamespace(logging_steps=1) if state is None: @@ -163,6 +160,22 @@ def _run_callback(plugin, cfg, *, args=None, state=None, control=None): should_save=False, should_training_stop=False, ) + + class DummyTrainer: + def __init__(self, state_obj, control_obj): + self.state = state_obj + self.control = control_obj + + def log(self, logs): + output = dict(logs) + output["step"] = self.state.global_step + self.state.log_history.append(output) + self.control.should_log = True + + dummy_trainer = DummyTrainer(state, control) + callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer) + assert callbacks, "expected aux-free callback to be registered" + callback = callbacks[0] callback.on_step_end(args=args, state=state, control=control) return state, control @@ -277,7 +290,7 @@ class TestAuxFreeAdapters(unittest.TestCase): def test_telemetry_logging(self): model, layer = _build_mixtral_model() layer.jitter_noise = 0.0 - cfg = _cfg(moe_afb_telemetry_interval=1) + cfg = _cfg() plugin = AuxFreeMoEPlugin() plugin.post_model_build(cfg, model)