slight changes
This commit is contained in:
@@ -52,6 +52,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_tokenizer(
|
def setup_model_and_tokenizer(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class TestTelemetryCallback:
|
|||||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
||||||
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
mock_telemetry_manager.send_event.assert_called_once_with(
|
||||||
event_type="train-started"
|
event_type="train-start"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_on_train_end(
|
def test_on_train_end(
|
||||||
@@ -130,7 +130,7 @@ class TestTelemetryCallback:
|
|||||||
mock_telemetry_manager.send_event.assert_called_once()
|
mock_telemetry_manager.send_event.assert_called_once()
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||||
|
|
||||||
assert call_args["event_type"] == "train-ended"
|
assert call_args["event_type"] == "train-end"
|
||||||
assert "loss" in call_args["properties"]
|
assert "loss" in call_args["properties"]
|
||||||
assert call_args["properties"]["loss"] == 2.5
|
assert call_args["properties"]["loss"] == 2.5
|
||||||
assert "learning_rate" in call_args["properties"]
|
assert "learning_rate" in call_args["properties"]
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ def test_send_errors_with_exception(mock_telemetry_manager):
|
|||||||
|
|
||||||
# Check that the error info was passed correctly
|
# Check that the error info was passed correctly
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||||
assert "test_func-errored" in call_args["event_type"]
|
assert "test_func-error" in call_args["event_type"]
|
||||||
assert "Test error" in call_args["properties"]["exception"]
|
assert "Test error" in call_args["properties"]["exception"]
|
||||||
assert "stack_trace" in call_args["properties"]
|
assert "stack_trace" in call_args["properties"]
|
||||||
|
|
||||||
@@ -336,5 +336,5 @@ def test_module_path_resolution(mock_telemetry_manager):
|
|||||||
assert mock_telemetry_manager.send_event.called
|
assert mock_telemetry_manager.send_event.called
|
||||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
||||||
|
|
||||||
expected_event_type = f"{current_module}.test_func-errored"
|
expected_event_type = f"{current_module}.test_func-error"
|
||||||
assert expected_event_type == event_type
|
assert expected_event_type == event_type
|
||||||
|
|||||||
Reference in New Issue
Block a user