aux_free_router: emit telemetry metrics

- log per-layer load stats, bias magnitude, and oscillation rate

- honor configurable telemetry interval with Trainer logging integration
This commit is contained in:
lhl
2025-10-28 08:27:48 +00:00
committed by Wing Lian
parent 949cdf01eb
commit 6eac9ac372
4 changed files with 147 additions and 19 deletions

View File

@@ -148,12 +148,23 @@ def _build_mixtral_model():
return DummyModel(layer), layer
def _run_callback(plugin, cfg):
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
dummy = SimpleNamespace()
callback.on_step_end(args=dummy, state=dummy, control=dummy)
if args is None:
args = SimpleNamespace(logging_steps=1)
if state is None:
state = SimpleNamespace(global_step=1, log_history=[])
if control is None:
control = SimpleNamespace(
should_log=False,
should_evaluate=False,
should_save=False,
should_training_stop=False,
)
callback.on_step_end(args=args, state=state, control=control)
return state, control
class TestAuxFreeAdapters(unittest.TestCase):
@@ -193,15 +204,10 @@ class TestAuxFreeAdapters(unittest.TestCase):
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
def _step():
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
_run_callback(plugin, cfg)
# Warmup steps should leave bias untouched.
_step()
@@ -244,11 +250,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
self.assertIsNotNone(plugin._shim)
self.assertIsNone(plugin._shim.ep_group)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close()
init_method = f"file://{tmp_init.name}"
@@ -256,12 +257,52 @@ class TestAuxFreeAdapters(unittest.TestCase):
try:
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
_run_callback(
plugin,
cfg,
args=SimpleNamespace(logging_steps=1),
state=SimpleNamespace(global_step=1, log_history=[]),
control=SimpleNamespace(
should_log=False,
should_evaluate=False,
should_save=False,
should_training_stop=False,
),
)
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
finally:
dist.destroy_process_group()
os.unlink(tmp_init.name)
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg(moe_afb_telemetry_interval=1)
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
hidden_dim = layer.config.hidden_size
hidden = torch.randn(2, 3, hidden_dim)
layer(hidden)
args = SimpleNamespace(logging_steps=1)
state = SimpleNamespace(global_step=1, log_history=[])
control = SimpleNamespace(
should_log=False,
should_evaluate=False,
should_save=False,
should_training_stop=False,
)
_run_callback(plugin, cfg, args=args, state=state, control=control)
self.assertTrue(control.should_log)
self.assertTrue(state.log_history)
telemetry = state.log_history[-1]
self.assertEqual(telemetry["step"], state.global_step)
self.assertIn("moe_afb/l0_load_min", telemetry)
self.assertIn("moe_afb/l0_load_max", telemetry)
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
if __name__ == "__main__":
unittest.main()