aux_free_router: emit telemetry metrics

- log per-layer load stats, bias magnitude, and oscillation rate

- honor configurable telemetry interval with Trainer logging integration
Author: lhl
Date: 2025-10-28 08:27:48 +00:00
Committed by: Wing Lian
Parent: 949cdf01eb
Commit: 6eac9ac372
4 changed files with 147 additions and 19 deletions
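To make the new knob concrete, here is a hypothetical config fragment written as a plain Python dict. The key names come from the docs change below; the values, and the presence of `logging_steps`, are illustrative only, not a recommended setup.

# Hypothetical config fragment for aux-free router telemetry (illustrative values).
# Key names are taken from the documentation change in this commit.
cfg = {
    "moe_afb_warmup_steps": 0,          # delay before bias updates apply
    "moe_bias_sync_group": "world",     # or "ep" together with expert_parallel_size
    "moe_afb_telemetry_interval": 50,   # emit router telemetry every 50 optimizer steps
    "logging_steps": 10,                # fallback interval when the key above is unset
}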


@@ -30,6 +30,7 @@ Config keys
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
Compatibility
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
@@ -37,4 +38,4 @@ Compatibility
Notes
- If you also enable Liger's aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
- Telemetry: future updates will log per-expert loads and bias magnitudes.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction at the configured interval.
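A minimal sketch of the interval fallback described above, assuming only the behaviour documented here and in the callback's `_should_log` further down; the helper name is made up for illustration.

# Minimal sketch of the documented interval fallback; the function name is
# made up, but the resolution order mirrors the callback's _should_log:
# explicit moe_afb_telemetry_interval first, then logging_steps, else disabled.
from typing import Optional


def effective_interval(telemetry_interval: Optional[int], logging_steps: int) -> int:
    """Return the optimizer-step interval for telemetry; 0 means disabled."""
    if telemetry_interval is not None and telemetry_interval > 0:
        return int(telemetry_interval)
    return max(1, int(logging_steps)) if logging_steps else 0


assert effective_interval(None, 10) == 10  # falls back to logging_steps
assert effective_interval(25, 10) == 25    # explicit interval wins
assert effective_interval(None, 0) == 0    # telemetry disabled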


@@ -17,6 +17,7 @@ class AuxFreeConfig:
bias_cap: float = 2.0
warmup_steps: int = 0
sync_group: str = "world" # or "ep"
telemetry_interval: Optional[int] = None
class AuxFreeState:
@@ -45,6 +46,7 @@ class AuxFreeShim:
self.state.cfg.sync_group == "ep" and self.ep_group is None
)
self._layer_modules: dict[int, torch.nn.Module] = {}
self._prev_bias_sign: dict[int, torch.Tensor] = {}
@torch.no_grad()
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -75,6 +77,9 @@ class AuxFreeShim:
"""Call once per optimizer step before per-layer updates."""
self.state.steps += 1
def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
return self._prev_bias_sign.get(layer_idx)
@torch.no_grad()
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
self._maybe_init_ep_group()
@@ -111,6 +116,7 @@ class AuxFreeShim:
bias.add_(delta)
if cfg.bias_cap is not None and cfg.bias_cap > 0:
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())
def _maybe_init_ep_group(self) -> None:
if not self._ep_group_pending:
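The shim now snapshots `torch.sign(bias)` after every update so a consumer can ask how often bias entries flipped sign between steps. A standalone sketch of that flip-fraction computation on toy tensors (not the plugin's actual call path):

# Standalone sketch of the sign-flip ("oscillation") measure used by telemetry.
# Toy tensors only; the real code compares the stored previous sign snapshot
# against torch.sign(bias) after each bias update.
import torch

prev_bias = torch.tensor([0.5, -0.3, 0.0, 1.2])
curr_bias = torch.tensor([-0.1, -0.4, 0.2, 1.1])

prev_sign = torch.sign(prev_bias)
curr_sign = torch.sign(curr_bias)

# An entry counts as a flip when the sign changed and at least one side is nonzero.
changed = (curr_sign != prev_sign) & ((curr_sign != 0) | (prev_sign != 0))
flip_frac = changed.float().mean().item()
print(flip_frac)  # 0.5 -> experts 0 and 2 flipped, out of 4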


@@ -36,9 +36,17 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
routing patches in a follow-up).
"""
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
def __init__(
self,
shim: AuxFreeShim,
layer_modules: list[torch.nn.Module],
telemetry_interval: Optional[int] = None,
):
self.shim = shim
self.layer_modules = layer_modules
self.telemetry_interval = telemetry_interval
self._prev_bias_sign: dict[int, torch.Tensor] = {}
self._telemetry_buffer: dict[int, dict[str, float]] = {}
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
# Iterate prepared MoE layers and apply the bias update rule.
@@ -59,10 +67,72 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
tokens_seen = int(counts_for_update.sum().item())
# local layer-state EMA and bias update
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
# reset step counts
counts.zero_()
if self._should_log(args, state):
if self._telemetry_buffer:
if not hasattr(state, "log_history"):
state.log_history = []
log_entry = {"step": state.global_step}
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
prefix = f"moe_afb/l{layer_idx}_"
for key, value in metrics.items():
log_entry[f"{prefix}{key}"] = value
state.log_history.append(log_entry)
control.should_log = True
self._telemetry_buffer.clear()
return control
def _collect_telemetry(
self,
layer_idx: int,
counts: torch.Tensor,
tokens_seen: int,
bias: torch.Tensor,
) -> None:
if tokens_seen <= 0:
return
freq = counts.float() / float(tokens_seen)
load_min = freq.min().item()
load_mean = freq.mean().item()
load_max = freq.max().item()
bias_abs_max = bias.abs().max().item()
prev_sign = self._prev_bias_sign.get(layer_idx)
current_sign = torch.sign(bias.detach())
if prev_sign is None or prev_sign.shape != current_sign.shape:
oscillation = 0.0
else:
changed = (current_sign != prev_sign) & (
(current_sign != 0) | (prev_sign != 0)
)
if changed.numel() == 0:
oscillation = 0.0
else:
oscillation = changed.float().mean().item()
self._prev_bias_sign[layer_idx] = current_sign.clone()
self._telemetry_buffer[layer_idx] = {
"load_min": load_min,
"load_mean": load_mean,
"load_max": load_max,
"bias_abs_max": bias_abs_max,
"bias_sign_flip_frac": oscillation,
}
def _should_log(self, args, state) -> bool:
interval = (
self.telemetry_interval
if self.telemetry_interval is not None and self.telemetry_interval > 0
else getattr(args, "logging_steps", 0)
)
interval = max(1, int(interval)) if interval else 0
if interval == 0:
return False
return state.global_step % interval == 0
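For reference, a telemetry entry appended to `state.log_history` by the block above is a flat dict keyed per layer with the `moe_afb/l{layer_idx}_` prefix; all metric values below are made up purely to show the shape.

# Illustrative shape of one telemetry entry appended to state.log_history.
# Values are invented; key names follow the f"moe_afb/l{layer_idx}_" prefix above.
example_entry = {
    "step": 120,
    "moe_afb/l0_load_min": 0.04,
    "moe_afb/l0_load_mean": 0.125,
    "moe_afb/l0_load_max": 0.22,
    "moe_afb/l0_bias_abs_max": 0.35,
    "moe_afb/l0_bias_sign_flip_frac": 0.0,
    "moe_afb/l1_load_min": 0.06,
    "moe_afb/l1_load_mean": 0.125,
    "moe_afb/l1_load_max": 0.19,
    "moe_afb/l1_bias_abs_max": 0.41,
    "moe_afb/l1_bias_sign_flip_frac": 0.125,
}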
class AuxFreeMoEPlugin(BasePlugin):
"""Plugin that enables aux-loss-free routing when configured."""
@@ -95,8 +165,14 @@ class AuxFreeMoEPlugin(BasePlugin):
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None)
af_cfg = AuxFreeConfig(
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
rate=rate,
momentum=momentum,
bias_cap=bias_cap,
warmup_steps=warmup,
sync_group=sync_group,
telemetry_interval=telemetry_interval,
)
# Discover layers to determine the layer and expert counts for state sizing
@@ -170,6 +246,10 @@ class AuxFreeMoEPlugin(BasePlugin):
return []
# gather concrete layer modules from handles
layers = [h.layer for h in self._handles]
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
cb = MoeAuxFreeBiasUpdateCallback(
self._shim,
layers,
telemetry_interval=self._shim.state.cfg.telemetry_interval,
)
LOG.info("AuxFreeMoE: registering post-step bias update callback")
return [cb]
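Since the entries end up in the standard HF `TrainerState.log_history`, they can be pulled back out after training. A minimal post-hoc reader, assuming only the `moe_afb/` key prefix used by the callback:

# Minimal post-hoc reader for the emitted metrics; assumes only the "moe_afb/"
# key prefix used by the callback. trainer.state.log_history is the standard
# Hugging Face TrainerState field (a list of dicts).
def collect_moe_telemetry(log_history: list[dict]) -> list[dict]:
    """Return only the log entries that carry aux-free router telemetry."""
    return [
        entry
        for entry in log_history
        if any(key.startswith("moe_afb/") for key in entry)
    ]


# Usage sketch:
# rows = collect_moe_telemetry(trainer.state.log_history)
# for row in rows:
#     print(row["step"], row.get("moe_afb/l0_load_max"))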


@@ -148,12 +148,23 @@ def _build_mixtral_model():
return DummyModel(layer), layer
def _run_callback(plugin, cfg):
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
dummy = SimpleNamespace()
callback.on_step_end(args=dummy, state=dummy, control=dummy)
if args is None:
args = SimpleNamespace(logging_steps=1)
if state is None:
state = SimpleNamespace(global_step=1, log_history=[])
if control is None:
control = SimpleNamespace(
should_log=False,
should_evaluate=False,
should_save=False,
should_training_stop=False,
)
callback.on_step_end(args=args, state=state, control=control)
return state, control
class TestAuxFreeAdapters(unittest.TestCase):
@@ -193,15 +204,10 @@ class TestAuxFreeAdapters(unittest.TestCase):
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
def _step():
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
_run_callback(plugin, cfg)
# Warmup steps should leave bias untouched.
_step()
@@ -244,11 +250,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
self.assertIsNotNone(plugin._shim)
self.assertIsNone(plugin._shim.ep_group)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close()
init_method = f"file://{tmp_init.name}"
@@ -256,12 +257,52 @@ class TestAuxFreeAdapters(unittest.TestCase):
try:
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
_run_callback(
plugin,
cfg,
args=SimpleNamespace(logging_steps=1),
state=SimpleNamespace(global_step=1, log_history=[]),
control=SimpleNamespace(
should_log=False,
should_evaluate=False,
should_save=False,
should_training_stop=False,
),
)
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
finally:
dist.destroy_process_group()
os.unlink(tmp_init.name)
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg(moe_afb_telemetry_interval=1)
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
hidden_dim = layer.config.hidden_size
hidden = torch.randn(2, 3, hidden_dim)
layer(hidden)
args = SimpleNamespace(logging_steps=1)
state = SimpleNamespace(global_step=1, log_history=[])
control = SimpleNamespace(
should_log=False,
should_evaluate=False,
should_save=False,
should_training_stop=False,
)
_run_callback(plugin, cfg, args=args, state=state, control=control)
self.assertTrue(control.should_log)
self.assertTrue(state.log_history)
telemetry = state.log_history[-1]
self.assertEqual(telemetry["step"], state.global_step)
self.assertIn("moe_afb/l0_load_min", telemetry)
self.assertIn("moe_afb/l0_load_max", telemetry)
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
if __name__ == "__main__":
unittest.main()