aux_free_router: emit telemetry metrics
- log per-layer load stats, bias magnitude, and oscillation rate - honor configurable telemetry interval with Trainer logging integration
This commit is contained in:
@@ -30,6 +30,7 @@ Config keys
|
|||||||
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||||
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||||
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||||
|
- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
|
||||||
|
|
||||||
Compatibility
|
Compatibility
|
||||||
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||||
@@ -37,4 +38,4 @@ Compatibility
|
|||||||
|
|
||||||
Notes
|
Notes
|
||||||
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
||||||
- Telemetry: future updates will log per-expert loads and bias magnitudes.
|
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction at the configured interval.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class AuxFreeConfig:
|
|||||||
bias_cap: float = 2.0
|
bias_cap: float = 2.0
|
||||||
warmup_steps: int = 0
|
warmup_steps: int = 0
|
||||||
sync_group: str = "world" # or "ep"
|
sync_group: str = "world" # or "ep"
|
||||||
|
telemetry_interval: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class AuxFreeState:
|
class AuxFreeState:
|
||||||
@@ -45,6 +46,7 @@ class AuxFreeShim:
|
|||||||
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
||||||
)
|
)
|
||||||
self._layer_modules: dict[int, torch.nn.Module] = {}
|
self._layer_modules: dict[int, torch.nn.Module] = {}
|
||||||
|
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
@@ -75,6 +77,9 @@ class AuxFreeShim:
|
|||||||
"""Call once per optimizer step before per-layer updates."""
|
"""Call once per optimizer step before per-layer updates."""
|
||||||
self.state.steps += 1
|
self.state.steps += 1
|
||||||
|
|
||||||
|
def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
|
||||||
|
return self._prev_bias_sign.get(layer_idx)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
||||||
self._maybe_init_ep_group()
|
self._maybe_init_ep_group()
|
||||||
@@ -111,6 +116,7 @@ class AuxFreeShim:
|
|||||||
bias.add_(delta)
|
bias.add_(delta)
|
||||||
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
||||||
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
||||||
|
self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())
|
||||||
|
|
||||||
def _maybe_init_ep_group(self) -> None:
|
def _maybe_init_ep_group(self) -> None:
|
||||||
if not self._ep_group_pending:
|
if not self._ep_group_pending:
|
||||||
|
|||||||
@@ -36,9 +36,17 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
|||||||
routing patches in a follow-up).
|
routing patches in a follow-up).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
shim: AuxFreeShim,
|
||||||
|
layer_modules: list[torch.nn.Module],
|
||||||
|
telemetry_interval: Optional[int] = None,
|
||||||
|
):
|
||||||
self.shim = shim
|
self.shim = shim
|
||||||
self.layer_modules = layer_modules
|
self.layer_modules = layer_modules
|
||||||
|
self.telemetry_interval = telemetry_interval
|
||||||
|
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||||
|
self._telemetry_buffer: dict[int, dict[str, float]] = {}
|
||||||
|
|
||||||
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
||||||
# Iterate prepared MoE layers and apply the bias update rule.
|
# Iterate prepared MoE layers and apply the bias update rule.
|
||||||
@@ -59,10 +67,72 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
|||||||
tokens_seen = int(counts_for_update.sum().item())
|
tokens_seen = int(counts_for_update.sum().item())
|
||||||
# local layer-state EMA and bias update
|
# local layer-state EMA and bias update
|
||||||
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
||||||
|
self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
|
||||||
# reset step counts
|
# reset step counts
|
||||||
counts.zero_()
|
counts.zero_()
|
||||||
|
|
||||||
|
if self._should_log(args, state):
|
||||||
|
if self._telemetry_buffer:
|
||||||
|
if not hasattr(state, "log_history"):
|
||||||
|
state.log_history = []
|
||||||
|
log_entry = {"step": state.global_step}
|
||||||
|
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
|
||||||
|
prefix = f"moe_afb/l{layer_idx}_"
|
||||||
|
for key, value in metrics.items():
|
||||||
|
log_entry[f"{prefix}{key}"] = value
|
||||||
|
state.log_history.append(log_entry)
|
||||||
|
control.should_log = True
|
||||||
|
self._telemetry_buffer.clear()
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def _collect_telemetry(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
counts: torch.Tensor,
|
||||||
|
tokens_seen: int,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
if tokens_seen <= 0:
|
||||||
|
return
|
||||||
|
freq = counts.float() / float(tokens_seen)
|
||||||
|
load_min = freq.min().item()
|
||||||
|
load_mean = freq.mean().item()
|
||||||
|
load_max = freq.max().item()
|
||||||
|
bias_abs_max = bias.abs().max().item()
|
||||||
|
|
||||||
|
prev_sign = self._prev_bias_sign.get(layer_idx)
|
||||||
|
current_sign = torch.sign(bias.detach())
|
||||||
|
if prev_sign is None or prev_sign.shape != current_sign.shape:
|
||||||
|
oscillation = 0.0
|
||||||
|
else:
|
||||||
|
changed = (current_sign != prev_sign) & (
|
||||||
|
(current_sign != 0) | (prev_sign != 0)
|
||||||
|
)
|
||||||
|
if changed.numel() == 0:
|
||||||
|
oscillation = 0.0
|
||||||
|
else:
|
||||||
|
oscillation = changed.float().mean().item()
|
||||||
|
self._prev_bias_sign[layer_idx] = current_sign.clone()
|
||||||
|
|
||||||
|
self._telemetry_buffer[layer_idx] = {
|
||||||
|
"load_min": load_min,
|
||||||
|
"load_mean": load_mean,
|
||||||
|
"load_max": load_max,
|
||||||
|
"bias_abs_max": bias_abs_max,
|
||||||
|
"bias_sign_flip_frac": oscillation,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _should_log(self, args, state) -> bool:
|
||||||
|
interval = (
|
||||||
|
self.telemetry_interval
|
||||||
|
if self.telemetry_interval is not None and self.telemetry_interval > 0
|
||||||
|
else getattr(args, "logging_steps", 0)
|
||||||
|
)
|
||||||
|
interval = max(1, int(interval)) if interval else 0
|
||||||
|
if interval == 0:
|
||||||
|
return False
|
||||||
|
return state.global_step % interval == 0
|
||||||
|
|
||||||
|
|
||||||
class AuxFreeMoEPlugin(BasePlugin):
|
class AuxFreeMoEPlugin(BasePlugin):
|
||||||
"""Plugin that enables aux-loss-free routing when configured."""
|
"""Plugin that enables aux-loss-free routing when configured."""
|
||||||
@@ -95,8 +165,14 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
||||||
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
||||||
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
||||||
|
telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None)
|
||||||
af_cfg = AuxFreeConfig(
|
af_cfg = AuxFreeConfig(
|
||||||
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
|
rate=rate,
|
||||||
|
momentum=momentum,
|
||||||
|
bias_cap=bias_cap,
|
||||||
|
warmup_steps=warmup,
|
||||||
|
sync_group=sync_group,
|
||||||
|
telemetry_interval=telemetry_interval,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Discover layers to count the number and experts for state sizing
|
# Discover layers to count the number and experts for state sizing
|
||||||
@@ -170,6 +246,10 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
return []
|
return []
|
||||||
# gather concrete layer modules from handles
|
# gather concrete layer modules from handles
|
||||||
layers = [h.layer for h in self._handles]
|
layers = [h.layer for h in self._handles]
|
||||||
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
|
cb = MoeAuxFreeBiasUpdateCallback(
|
||||||
|
self._shim,
|
||||||
|
layers,
|
||||||
|
telemetry_interval=self._shim.state.cfg.telemetry_interval,
|
||||||
|
)
|
||||||
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
||||||
return [cb]
|
return [cb]
|
||||||
|
|||||||
@@ -148,12 +148,23 @@ def _build_mixtral_model():
|
|||||||
return DummyModel(layer), layer
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
def _run_callback(plugin, cfg):
|
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||||
assert callbacks, "expected aux-free callback to be registered"
|
assert callbacks, "expected aux-free callback to be registered"
|
||||||
callback = callbacks[0]
|
callback = callbacks[0]
|
||||||
dummy = SimpleNamespace()
|
if args is None:
|
||||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
args = SimpleNamespace(logging_steps=1)
|
||||||
|
if state is None:
|
||||||
|
state = SimpleNamespace(global_step=1, log_history=[])
|
||||||
|
if control is None:
|
||||||
|
control = SimpleNamespace(
|
||||||
|
should_log=False,
|
||||||
|
should_evaluate=False,
|
||||||
|
should_save=False,
|
||||||
|
should_training_stop=False,
|
||||||
|
)
|
||||||
|
callback.on_step_end(args=args, state=state, control=control)
|
||||||
|
return state, control
|
||||||
|
|
||||||
|
|
||||||
class TestAuxFreeAdapters(unittest.TestCase):
|
class TestAuxFreeAdapters(unittest.TestCase):
|
||||||
@@ -193,15 +204,10 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
plugin = AuxFreeMoEPlugin()
|
plugin = AuxFreeMoEPlugin()
|
||||||
plugin.post_model_build(cfg, model)
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
|
||||||
self.assertTrue(callbacks)
|
|
||||||
callback = callbacks[0]
|
|
||||||
dummy = SimpleNamespace()
|
|
||||||
|
|
||||||
def _step():
|
def _step():
|
||||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
block(hidden)
|
block(hidden)
|
||||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
# Warmup steps should leave bias untouched.
|
# Warmup steps should leave bias untouched.
|
||||||
_step()
|
_step()
|
||||||
@@ -244,11 +250,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
self.assertIsNotNone(plugin._shim)
|
self.assertIsNotNone(plugin._shim)
|
||||||
self.assertIsNone(plugin._shim.ep_group)
|
self.assertIsNone(plugin._shim.ep_group)
|
||||||
|
|
||||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
|
||||||
self.assertTrue(callbacks)
|
|
||||||
callback = callbacks[0]
|
|
||||||
dummy = SimpleNamespace()
|
|
||||||
|
|
||||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||||
tmp_init.close()
|
tmp_init.close()
|
||||||
init_method = f"file://{tmp_init.name}"
|
init_method = f"file://{tmp_init.name}"
|
||||||
@@ -256,12 +257,52 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
try:
|
try:
|
||||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
block(hidden)
|
block(hidden)
|
||||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
_run_callback(
|
||||||
|
plugin,
|
||||||
|
cfg,
|
||||||
|
args=SimpleNamespace(logging_steps=1),
|
||||||
|
state=SimpleNamespace(global_step=1, log_history=[]),
|
||||||
|
control=SimpleNamespace(
|
||||||
|
should_log=False,
|
||||||
|
should_evaluate=False,
|
||||||
|
should_save=False,
|
||||||
|
should_training_stop=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||||
finally:
|
finally:
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
os.unlink(tmp_init.name)
|
os.unlink(tmp_init.name)
|
||||||
|
|
||||||
|
def test_telemetry_logging(self):
|
||||||
|
model, layer = _build_mixtral_model()
|
||||||
|
layer.jitter_noise = 0.0
|
||||||
|
cfg = _cfg(moe_afb_telemetry_interval=1)
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
hidden_dim = layer.config.hidden_size
|
||||||
|
hidden = torch.randn(2, 3, hidden_dim)
|
||||||
|
layer(hidden)
|
||||||
|
|
||||||
|
args = SimpleNamespace(logging_steps=1)
|
||||||
|
state = SimpleNamespace(global_step=1, log_history=[])
|
||||||
|
control = SimpleNamespace(
|
||||||
|
should_log=False,
|
||||||
|
should_evaluate=False,
|
||||||
|
should_save=False,
|
||||||
|
should_training_stop=False,
|
||||||
|
)
|
||||||
|
_run_callback(plugin, cfg, args=args, state=state, control=control)
|
||||||
|
|
||||||
|
self.assertTrue(control.should_log)
|
||||||
|
self.assertTrue(state.log_history)
|
||||||
|
telemetry = state.log_history[-1]
|
||||||
|
self.assertEqual(telemetry["step"], state.global_step)
|
||||||
|
self.assertIn("moe_afb/l0_load_min", telemetry)
|
||||||
|
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||||
|
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user