aux_free_router: emit telemetry metrics
- log per-layer load stats, bias magnitude, and oscillation rate - honor configurable telemetry interval with Trainer logging integration
This commit is contained in:
@@ -30,6 +30,7 @@ Config keys
|
||||
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||
- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
|
||||
|
||||
Compatibility
|
||||
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||
@@ -37,4 +38,4 @@ Compatibility
|
||||
|
||||
Notes
|
||||
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
||||
- Telemetry: future updates will log per-expert loads and bias magnitudes.
|
||||
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction at the configured interval.
|
||||
|
||||
@@ -17,6 +17,7 @@ class AuxFreeConfig:
|
||||
bias_cap: float = 2.0
|
||||
warmup_steps: int = 0
|
||||
sync_group: str = "world" # or "ep"
|
||||
telemetry_interval: Optional[int] = None
|
||||
|
||||
|
||||
class AuxFreeState:
|
||||
@@ -45,6 +46,7 @@ class AuxFreeShim:
|
||||
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
||||
)
|
||||
self._layer_modules: dict[int, torch.nn.Module] = {}
|
||||
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -75,6 +77,9 @@ class AuxFreeShim:
|
||||
"""Call once per optimizer step before per-layer updates."""
|
||||
self.state.steps += 1
|
||||
|
||||
def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
|
||||
return self._prev_bias_sign.get(layer_idx)
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
||||
self._maybe_init_ep_group()
|
||||
@@ -111,6 +116,7 @@ class AuxFreeShim:
|
||||
bias.add_(delta)
|
||||
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
||||
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
||||
self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())
|
||||
|
||||
def _maybe_init_ep_group(self) -> None:
|
||||
if not self._ep_group_pending:
|
||||
|
||||
@@ -36,9 +36,17 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||
routing patches in a follow-up).
|
||||
"""
|
||||
|
||||
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
|
||||
def __init__(
|
||||
self,
|
||||
shim: AuxFreeShim,
|
||||
layer_modules: list[torch.nn.Module],
|
||||
telemetry_interval: Optional[int] = None,
|
||||
):
|
||||
self.shim = shim
|
||||
self.layer_modules = layer_modules
|
||||
self.telemetry_interval = telemetry_interval
|
||||
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||
self._telemetry_buffer: dict[int, dict[str, float]] = {}
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
||||
# Iterate prepared MoE layers and apply the bias update rule.
|
||||
@@ -59,10 +67,72 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||
tokens_seen = int(counts_for_update.sum().item())
|
||||
# local layer-state EMA and bias update
|
||||
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
||||
self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
|
||||
# reset step counts
|
||||
counts.zero_()
|
||||
|
||||
if self._should_log(args, state):
|
||||
if self._telemetry_buffer:
|
||||
if not hasattr(state, "log_history"):
|
||||
state.log_history = []
|
||||
log_entry = {"step": state.global_step}
|
||||
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
|
||||
prefix = f"moe_afb/l{layer_idx}_"
|
||||
for key, value in metrics.items():
|
||||
log_entry[f"{prefix}{key}"] = value
|
||||
state.log_history.append(log_entry)
|
||||
control.should_log = True
|
||||
self._telemetry_buffer.clear()
|
||||
return control
|
||||
|
||||
def _collect_telemetry(
|
||||
self,
|
||||
layer_idx: int,
|
||||
counts: torch.Tensor,
|
||||
tokens_seen: int,
|
||||
bias: torch.Tensor,
|
||||
) -> None:
|
||||
if tokens_seen <= 0:
|
||||
return
|
||||
freq = counts.float() / float(tokens_seen)
|
||||
load_min = freq.min().item()
|
||||
load_mean = freq.mean().item()
|
||||
load_max = freq.max().item()
|
||||
bias_abs_max = bias.abs().max().item()
|
||||
|
||||
prev_sign = self._prev_bias_sign.get(layer_idx)
|
||||
current_sign = torch.sign(bias.detach())
|
||||
if prev_sign is None or prev_sign.shape != current_sign.shape:
|
||||
oscillation = 0.0
|
||||
else:
|
||||
changed = (current_sign != prev_sign) & (
|
||||
(current_sign != 0) | (prev_sign != 0)
|
||||
)
|
||||
if changed.numel() == 0:
|
||||
oscillation = 0.0
|
||||
else:
|
||||
oscillation = changed.float().mean().item()
|
||||
self._prev_bias_sign[layer_idx] = current_sign.clone()
|
||||
|
||||
self._telemetry_buffer[layer_idx] = {
|
||||
"load_min": load_min,
|
||||
"load_mean": load_mean,
|
||||
"load_max": load_max,
|
||||
"bias_abs_max": bias_abs_max,
|
||||
"bias_sign_flip_frac": oscillation,
|
||||
}
|
||||
|
||||
def _should_log(self, args, state) -> bool:
|
||||
interval = (
|
||||
self.telemetry_interval
|
||||
if self.telemetry_interval is not None and self.telemetry_interval > 0
|
||||
else getattr(args, "logging_steps", 0)
|
||||
)
|
||||
interval = max(1, int(interval)) if interval else 0
|
||||
if interval == 0:
|
||||
return False
|
||||
return state.global_step % interval == 0
|
||||
|
||||
|
||||
class AuxFreeMoEPlugin(BasePlugin):
|
||||
"""Plugin that enables aux-loss-free routing when configured."""
|
||||
@@ -95,8 +165,14 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
||||
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
||||
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
||||
telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None)
|
||||
af_cfg = AuxFreeConfig(
|
||||
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
|
||||
rate=rate,
|
||||
momentum=momentum,
|
||||
bias_cap=bias_cap,
|
||||
warmup_steps=warmup,
|
||||
sync_group=sync_group,
|
||||
telemetry_interval=telemetry_interval,
|
||||
)
|
||||
|
||||
# Discover layers to count the number and experts for state sizing
|
||||
@@ -170,6 +246,10 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
return []
|
||||
# gather concrete layer modules from handles
|
||||
layers = [h.layer for h in self._handles]
|
||||
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
|
||||
cb = MoeAuxFreeBiasUpdateCallback(
|
||||
self._shim,
|
||||
layers,
|
||||
telemetry_interval=self._shim.state.cfg.telemetry_interval,
|
||||
)
|
||||
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
||||
return [cb]
|
||||
|
||||
@@ -148,12 +148,23 @@ def _build_mixtral_model():
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg):
|
||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
if args is None:
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
if state is None:
|
||||
state = SimpleNamespace(global_step=1, log_history=[])
|
||||
if control is None:
|
||||
control = SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
)
|
||||
callback.on_step_end(args=args, state=state, control=control)
|
||||
return state, control
|
||||
|
||||
|
||||
class TestAuxFreeAdapters(unittest.TestCase):
|
||||
@@ -193,15 +204,10 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
self.assertTrue(callbacks)
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
|
||||
def _step():
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
# Warmup steps should leave bias untouched.
|
||||
_step()
|
||||
@@ -244,11 +250,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
self.assertIsNotNone(plugin._shim)
|
||||
self.assertIsNone(plugin._shim.ep_group)
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
self.assertTrue(callbacks)
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
|
||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||
tmp_init.close()
|
||||
init_method = f"file://{tmp_init.name}"
|
||||
@@ -256,12 +257,52 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
try:
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
_run_callback(
|
||||
plugin,
|
||||
cfg,
|
||||
args=SimpleNamespace(logging_steps=1),
|
||||
state=SimpleNamespace(global_step=1, log_history=[]),
|
||||
control=SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
),
|
||||
)
|
||||
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||
finally:
|
||||
dist.destroy_process_group()
|
||||
os.unlink(tmp_init.name)
|
||||
|
||||
def test_telemetry_logging(self):
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0
|
||||
cfg = _cfg(moe_afb_telemetry_interval=1)
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
hidden_dim = layer.config.hidden_size
|
||||
hidden = torch.randn(2, 3, hidden_dim)
|
||||
layer(hidden)
|
||||
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
state = SimpleNamespace(global_step=1, log_history=[])
|
||||
control = SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
)
|
||||
_run_callback(plugin, cfg, args=args, state=state, control=control)
|
||||
|
||||
self.assertTrue(control.should_log)
|
||||
self.assertTrue(state.log_history)
|
||||
telemetry = state.log_history[-1]
|
||||
self.assertEqual(telemetry["step"], state.global_step)
|
||||
self.assertIn("moe_afb/l0_load_min", telemetry)
|
||||
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user