improve: align aux-free telemetry with Trainer logging

This commit is contained in:
lhl
2025-11-11 17:00:48 +00:00
committed by Wing Lian
parent 966a4555db
commit 676d5e855d
6 changed files with 39 additions and 41 deletions

View File

@@ -149,9 +149,6 @@ def _build_mixtral_model():
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
if args is None:
args = SimpleNamespace(logging_steps=1)
if state is None:
@@ -163,6 +160,22 @@ def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
should_save=False,
should_training_stop=False,
)
# Minimal trainer stand-in used by this test helper: it holds the shared
# state/control objects and offers a log() method that records into them.
# NOTE(review): indentation is not preserved in this view; the class appears
# to be defined inside _run_callback — confirm against the real file.
class DummyTrainer:
# Store references to the caller's state and control objects (no copies),
# so anything log() writes is visible through the caller's own objects.
def __init__(self, state_obj, control_obj):
self.state = state_obj
self.control = control_obj
# Append `logs` to state.log_history, tagged with the current global step.
def log(self, logs):
# Build a new dict so adding "step" does not mutate the caller's `logs`.
output = dict(logs)
output["step"] = self.state.global_step
self.state.log_history.append(output)
# Set should_log on the shared control object after recording the entry.
self.control.should_log = True
dummy_trainer = DummyTrainer(state, control)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
callback.on_step_end(args=args, state=state, control=control)
return state, control
@@ -277,7 +290,7 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg(moe_afb_telemetry_interval=1)
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)