aux_free_router: emit telemetry metrics
- log per-layer load stats, bias magnitude, and oscillation rate
- honor configurable telemetry interval with Trainer logging integration
This commit is contained in:
@@ -148,12 +148,23 @@ def _build_mixtral_model():
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg):
|
||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
if args is None:
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
if state is None:
|
||||
state = SimpleNamespace(global_step=1, log_history=[])
|
||||
if control is None:
|
||||
control = SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
)
|
||||
callback.on_step_end(args=args, state=state, control=control)
|
||||
return state, control
|
||||
|
||||
|
||||
class TestAuxFreeAdapters(unittest.TestCase):
|
||||
@@ -193,15 +204,10 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
self.assertTrue(callbacks)
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
|
||||
def _step():
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
# Warmup steps should leave bias untouched.
|
||||
_step()
|
||||
@@ -244,11 +250,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
self.assertIsNotNone(plugin._shim)
|
||||
self.assertIsNone(plugin._shim.ep_group)
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
self.assertTrue(callbacks)
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
|
||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||
tmp_init.close()
|
||||
init_method = f"file://{tmp_init.name}"
|
||||
@@ -256,12 +257,52 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
try:
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
_run_callback(
|
||||
plugin,
|
||||
cfg,
|
||||
args=SimpleNamespace(logging_steps=1),
|
||||
state=SimpleNamespace(global_step=1, log_history=[]),
|
||||
control=SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
),
|
||||
)
|
||||
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||
finally:
|
||||
dist.destroy_process_group()
|
||||
os.unlink(tmp_init.name)
|
||||
|
||||
def test_telemetry_logging(self):
    """Check that one callback step records aux-free telemetry.

    With ``moe_afb_telemetry_interval=1`` the test runs a single forward
    pass through the layer, fires the callback once via ``_run_callback``,
    and then expects ``control.should_log`` to be set and the newest
    ``state.log_history`` entry to carry the step number plus the layer-0
    load and bias metrics.
    """
    model, layer = _build_mixtral_model()
    # NOTE(review): presumably zeroed so the forward pass adds no jitter to
    # the router inputs -- confirm against the layer implementation.
    layer.jitter_noise = 0.0
    cfg = _cfg(moe_afb_telemetry_interval=1)
    plugin = AuxFreeMoEPlugin()
    plugin.post_model_build(cfg, model)

    # One forward pass before the callback runs.
    hidden = torch.randn(2, 3, layer.config.hidden_size)
    layer(hidden)

    args = SimpleNamespace(logging_steps=1)
    state = SimpleNamespace(global_step=1, log_history=[])
    control = SimpleNamespace(
        should_log=False,
        should_evaluate=False,
        should_save=False,
        should_training_stop=False,
    )
    _run_callback(plugin, cfg, args=args, state=state, control=control)

    self.assertTrue(control.should_log)
    self.assertTrue(state.log_history)
    latest = state.log_history[-1]
    self.assertEqual(latest["step"], state.global_step)
    self.assertIn("moe_afb/l0_load_min", latest)
    self.assertIn("moe_afb/l0_load_max", latest)
    self.assertIn("moe_afb/l0_bias_abs_max", latest)
|
||||
|
||||
|
||||
# Run the unittest suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user