diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 441f8997d..c0fb3c01a 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -52,6 +52,7 @@ LOG = get_logger(__name__) TELEMETRY_MANAGER = TelemetryManager.get_instance() + def setup_model_and_tokenizer( cfg: DictDefault, ) -> tuple[ diff --git a/tests/telemetry/test_callbacks.py b/tests/telemetry/test_callbacks.py index 6303812cc..4324126e7 100644 --- a/tests/telemetry/test_callbacks.py +++ b/tests/telemetry/test_callbacks.py @@ -113,7 +113,7 @@ class TestTelemetryCallback: callback.on_train_begin(training_args, trainer_state, trainer_control) mock_telemetry_manager.send_event.assert_called_once_with( - event_type="train-started" + event_type="train-start" ) def test_on_train_end( @@ -130,7 +130,7 @@ class TestTelemetryCallback: mock_telemetry_manager.send_event.assert_called_once() call_args = mock_telemetry_manager.send_event.call_args[1] - assert call_args["event_type"] == "train-ended" + assert call_args["event_type"] == "train-end" assert "loss" in call_args["properties"] assert call_args["properties"]["loss"] == 2.5 assert "learning_rate" in call_args["properties"] diff --git a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py index 021d5fbd8..3d00c0f28 100644 --- a/tests/telemetry/test_errors.py +++ b/tests/telemetry/test_errors.py @@ -253,7 +253,7 @@ def test_send_errors_with_exception(mock_telemetry_manager): # Check that the error info was passed correctly call_args = mock_telemetry_manager.send_event.call_args[1] - assert "test_func-errored" in call_args["event_type"] + assert "test_func-error" in call_args["event_type"] assert "Test error" in call_args["properties"]["exception"] assert "stack_trace" in call_args["properties"] @@ -336,5 +336,5 @@ def test_module_path_resolution(mock_telemetry_manager): assert mock_telemetry_manager.send_event.called event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"] - expected_event_type = f"{current_module}.test_func-errored" + expected_event_type = f"{current_module}.test_func-error" assert expected_event_type == event_type