improve: align aux-free telemetry with Trainer logging
This commit is contained in:
@@ -149,9 +149,6 @@ def _build_mixtral_model():
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
if args is None:
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
if state is None:
|
||||
@@ -163,6 +160,22 @@ def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
)
|
||||
|
||||
class DummyTrainer:
|
||||
def __init__(self, state_obj, control_obj):
|
||||
self.state = state_obj
|
||||
self.control = control_obj
|
||||
|
||||
def log(self, logs):
|
||||
output = dict(logs)
|
||||
output["step"] = self.state.global_step
|
||||
self.state.log_history.append(output)
|
||||
self.control.should_log = True
|
||||
|
||||
dummy_trainer = DummyTrainer(state, control)
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
callback.on_step_end(args=args, state=state, control=control)
|
||||
return state, control
|
||||
|
||||
@@ -277,7 +290,7 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
def test_telemetry_logging(self):
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0
|
||||
cfg = _cfg(moe_afb_telemetry_interval=1)
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user