progress on telemetry: config load, process, model load, train start / end, error tracking
This commit is contained in:
@@ -1,71 +1,65 @@
|
||||
"""Tests for TelemetryManager class and utilities"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
|
||||
from axolotl.telemetry import TelemetryManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_whitelist(tmp_path):
|
||||
"""Create a temporary whitelist file for testing"""
|
||||
whitelist_content = {
|
||||
"organizations": ["meta", "mistral"],
|
||||
"models": ["llama", "mistral-7b"],
|
||||
"organizations": ["meta-llama", "mistralai"],
|
||||
}
|
||||
whitelist_file = tmp_path / "whitelist.yaml"
|
||||
with open(whitelist_file, "w") as f:
|
||||
with open(whitelist_file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(whitelist_content, f)
|
||||
return str(whitelist_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config(mock_whitelist):
|
||||
"""Create a TelemetryConfig with test settings"""
|
||||
return TelemetryConfig(
|
||||
host="https://test.posthog.com", whitelist_path=mock_whitelist
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(config):
|
||||
def manager():
|
||||
"""Create a TelemetryManager instance with mocked PostHog"""
|
||||
with patch("posthog.capture"):
|
||||
return TelemetryManager(config)
|
||||
return TelemetryManager()
|
||||
|
||||
|
||||
def test_telemetry_disabled_by_default():
|
||||
"""Test that telemetry is disabled by default"""
|
||||
manager = TelemetryManager(TelemetryConfig())
|
||||
manager = TelemetryManager()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_opt_in():
|
||||
"""Test that telemetry can be enabled via environment variable"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
||||
manager = TelemetryManager(TelemetryConfig())
|
||||
manager = TelemetryManager()
|
||||
assert manager.enabled
|
||||
|
||||
|
||||
def test_do_not_track_override():
|
||||
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
|
||||
manager = TelemetryManager(TelemetryConfig())
|
||||
manager = TelemetryManager()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def test_whitelist_checking(manager):
|
||||
"""Test model whitelist functionality"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
||||
# Should match organization
|
||||
assert manager._is_whitelisted("meta/llama-7b")
|
||||
# Should match model name
|
||||
assert manager._is_whitelisted("mistral-7b-instruct")
|
||||
# Should not match
|
||||
assert not manager._is_whitelisted("unknown/model")
|
||||
# Should handle case insensitively
|
||||
assert manager._is_whitelisted("meta/Llama-7b")
|
||||
# Should match organization
|
||||
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||
# Should match model name
|
||||
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||
# Should not match
|
||||
assert not manager._is_whitelisted("unknown/model")
|
||||
# Should handle case insensitively
|
||||
assert manager._is_whitelisted("meta/Llama-7b")
|
||||
|
||||
|
||||
def test_event_tracking(manager):
|
||||
@@ -81,33 +75,6 @@ def test_event_tracking(manager):
|
||||
assert "system_info" in mock_capture.call_args[1]["properties"]
|
||||
|
||||
|
||||
def test_model_tracking(manager):
|
||||
"""Test model load tracking"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
||||
model_config = ModelConfig(
|
||||
base_model="meta/llama-7b",
|
||||
model_type="decoder",
|
||||
hidden_size=4096,
|
||||
num_layers=32,
|
||||
num_attention_heads=32,
|
||||
tokenizer_config={},
|
||||
flash_attention=True,
|
||||
quantization_config=None,
|
||||
training_approach="lora",
|
||||
)
|
||||
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
manager.enabled = True
|
||||
manager.track_model_load(model_config)
|
||||
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "model_load"
|
||||
assert (
|
||||
mock_capture.call_args[1]["properties"]["model_config"]
|
||||
== model_config.to_dict()
|
||||
)
|
||||
|
||||
|
||||
def test_training_context(manager):
|
||||
"""Test training context manager"""
|
||||
config = {"model": "llama", "batch_size": 8}
|
||||
@@ -141,6 +108,7 @@ def test_training_error(manager):
|
||||
assert "training_error" in events
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def test_path_sanitization(manager):
|
||||
"""Test path sanitization"""
|
||||
path = "/home/user/sensitive/data.txt"
|
||||
@@ -149,6 +117,7 @@ def test_path_sanitization(manager):
|
||||
assert "/home/user" not in sanitized
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def test_error_sanitization(manager):
|
||||
"""Test error message sanitization"""
|
||||
error = "Failed to load /home/user/sensitive/data.txt: File not found"
|
||||
|
||||
Reference in New Issue
Block a user