This commit is contained in:
Dan Saunders
2025-02-19 13:55:04 +00:00
parent 07ffd47f2b
commit c50610375f
4 changed files with 34 additions and 26 deletions

View File

@@ -1,5 +1,15 @@
"""Init for axolotl.telemetry module."""
from .manager import ModelConfig, TelemetryConfig, TelemetryManager, init_telemetry_manager
from .manager import (
ModelConfig,
TelemetryConfig,
TelemetryManager,
init_telemetry_manager,
)
__all__ = ["TelemetryConfig", "TelemetryManager", "ModelConfig", "init_telemetry_manager"]
__all__ = [
"TelemetryConfig",
"TelemetryManager",
"ModelConfig",
"init_telemetry_manager",
]

View File

@@ -126,8 +126,7 @@ class TelemetryManager:
base_model = base_model.lower()
return any(
org.lower() in base_model
for org in self.whitelist.get("organizations", [])
org.lower() in base_model for org in self.whitelist.get("organizations", [])
)
def _init_posthog(self):
@@ -192,7 +191,7 @@ class TelemetryManager:
"run_id": self.run_id,
"system_info": system_info,
**properties,
}
},
)
except Exception as e:
logger.warning(f"Failed to send telemetry event: {e}")
@@ -266,4 +265,4 @@ class TelemetryManager:
def init_telemetry_manager() -> TelemetryManager:
"""Initialize telemetry system"""
return TelemetryManager(TelemetryConfig())
return TelemetryManager(TelemetryConfig())

View File

@@ -8,4 +8,4 @@ organizations:
- "NousResearch"
- "allenai"
- "amd"
- "tiiuae"
- "tiiuae"

View File

@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
@pytest.fixture
@@ -12,7 +12,7 @@ def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing"""
whitelist_content = {
"organizations": ["meta", "mistral"],
"models": ["llama", "mistral-7b"]
"models": ["llama", "mistral-7b"],
}
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w") as f:
@@ -24,8 +24,7 @@ def mock_whitelist(tmp_path):
def config(mock_whitelist):
"""Create a TelemetryConfig with test settings"""
return TelemetryConfig(
host="https://test.posthog.com",
whitelist_path=mock_whitelist
host="https://test.posthog.com", whitelist_path=mock_whitelist
)
@@ -51,10 +50,7 @@ def test_telemetry_opt_in():
def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, {
"AXOLOTL_TELEMETRY": "1",
"DO_NOT_TRACK": "1"
}):
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
manager = TelemetryManager(TelemetryConfig())
assert not manager.enabled
@@ -77,7 +73,7 @@ def test_event_tracking(manager):
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"]["key"] == "value"
@@ -97,28 +93,31 @@ def test_model_tracking(manager):
tokenizer_config={},
flash_attention=True,
quantization_config=None,
training_approach="lora"
training_approach="lora",
)
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_model_load(model_config)
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "model_load"
assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict()
assert (
mock_capture.call_args[1]["properties"]["model_config"]
== model_config.to_dict()
)
def test_training_context(manager):
"""Test training context manager"""
config = {"model": "llama", "batch_size": 8}
with patch("posthog.capture") as mock_capture:
manager.enabled = True
with manager.track_training(config):
pass # Simulate successful training
# Should have captured training_start and training_complete
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events
@@ -128,14 +127,14 @@ def test_training_context(manager):
def test_training_error(manager):
"""Test training context manager with error"""
config = {"model": "llama", "batch_size": 8}
with patch("posthog.capture") as mock_capture:
manager.enabled = True
with pytest.raises(ValueError):
with manager.track_training(config):
raise ValueError("Test error")
# Should have captured training_start and training_error
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events