This commit is contained in:
Dan Saunders
2025-02-19 13:55:04 +00:00
parent 07ffd47f2b
commit c50610375f
4 changed files with 34 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
@pytest.fixture
@@ -12,7 +12,7 @@ def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing"""
whitelist_content = {
"organizations": ["meta", "mistral"],
"models": ["llama", "mistral-7b"]
"models": ["llama", "mistral-7b"],
}
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w") as f:
@@ -24,8 +24,7 @@ def mock_whitelist(tmp_path):
def config(mock_whitelist):
"""Create a TelemetryConfig with test settings"""
return TelemetryConfig(
host="https://test.posthog.com",
whitelist_path=mock_whitelist
host="https://test.posthog.com", whitelist_path=mock_whitelist
)
@@ -51,10 +50,7 @@ def test_telemetry_opt_in():
def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, {
"AXOLOTL_TELEMETRY": "1",
"DO_NOT_TRACK": "1"
}):
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
manager = TelemetryManager(TelemetryConfig())
assert not manager.enabled
@@ -77,7 +73,7 @@ def test_event_tracking(manager):
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"]["key"] == "value"
@@ -97,28 +93,31 @@ def test_model_tracking(manager):
tokenizer_config={},
flash_attention=True,
quantization_config=None,
training_approach="lora"
training_approach="lora",
)
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_model_load(model_config)
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "model_load"
assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict()
assert (
mock_capture.call_args[1]["properties"]["model_config"]
== model_config.to_dict()
)
def test_training_context(manager):
"""Test training context manager"""
config = {"model": "llama", "batch_size": 8}
with patch("posthog.capture") as mock_capture:
manager.enabled = True
with manager.track_training(config):
pass # Simulate successful training
# Should have captured training_start and training_complete
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events
@@ -128,14 +127,14 @@ def test_training_context(manager):
def test_training_error(manager):
"""Test training context manager with error"""
config = {"model": "llama", "batch_size": 8}
with patch("posthog.capture") as mock_capture:
manager.enabled = True
with pytest.raises(ValueError):
with manager.track_training(config):
raise ValueError("Test error")
# Should have captured training_start and training_error
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events