This commit is contained in:
Dan Saunders
2025-02-19 13:55:04 +00:00
parent 07ffd47f2b
commit c50610375f
4 changed files with 34 additions and 26 deletions

View File

@@ -1,5 +1,15 @@
"""Init for axolotl.telemetry module.""" """Init for axolotl.telemetry module."""
from .manager import ModelConfig, TelemetryConfig, TelemetryManager, init_telemetry_manager from .manager import (
ModelConfig,
TelemetryConfig,
TelemetryManager,
init_telemetry_manager,
)
__all__ = ["TelemetryConfig", "TelemetryManager", "ModelConfig", "init_telemetry_manager"] __all__ = [
"TelemetryConfig",
"TelemetryManager",
"ModelConfig",
"init_telemetry_manager",
]

View File

@@ -126,8 +126,7 @@ class TelemetryManager:
base_model = base_model.lower() base_model = base_model.lower()
return any( return any(
org.lower() in base_model org.lower() in base_model for org in self.whitelist.get("organizations", [])
for org in self.whitelist.get("organizations", [])
) )
def _init_posthog(self): def _init_posthog(self):
@@ -192,7 +191,7 @@ class TelemetryManager:
"run_id": self.run_id, "run_id": self.run_id,
"system_info": system_info, "system_info": system_info,
**properties, **properties,
} },
) )
except Exception as e: except Exception as e:
logger.warning(f"Failed to send telemetry event: {e}") logger.warning(f"Failed to send telemetry event: {e}")
@@ -266,4 +265,4 @@ class TelemetryManager:
def init_telemetry_manager() -> TelemetryManager: def init_telemetry_manager() -> TelemetryManager:
"""Initialize telemetry system""" """Initialize telemetry system"""
return TelemetryManager(TelemetryConfig()) return TelemetryManager(TelemetryConfig())

View File

@@ -8,4 +8,4 @@ organizations:
- "NousResearch" - "NousResearch"
- "allenai" - "allenai"
- "amd" - "amd"
- "tiiuae" - "tiiuae"

View File

@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest import pytest
import yaml import yaml
from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
@pytest.fixture @pytest.fixture
@@ -12,7 +12,7 @@ def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing""" """Create a temporary whitelist file for testing"""
whitelist_content = { whitelist_content = {
"organizations": ["meta", "mistral"], "organizations": ["meta", "mistral"],
"models": ["llama", "mistral-7b"] "models": ["llama", "mistral-7b"],
} }
whitelist_file = tmp_path / "whitelist.yaml" whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w") as f: with open(whitelist_file, "w") as f:
@@ -24,8 +24,7 @@ def mock_whitelist(tmp_path):
def config(mock_whitelist): def config(mock_whitelist):
"""Create a TelemetryConfig with test settings""" """Create a TelemetryConfig with test settings"""
return TelemetryConfig( return TelemetryConfig(
host="https://test.posthog.com", host="https://test.posthog.com", whitelist_path=mock_whitelist
whitelist_path=mock_whitelist
) )
@@ -51,10 +50,7 @@ def test_telemetry_opt_in():
def test_do_not_track_override(): def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY""" """Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, { with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
"AXOLOTL_TELEMETRY": "1",
"DO_NOT_TRACK": "1"
}):
manager = TelemetryManager(TelemetryConfig()) manager = TelemetryManager(TelemetryConfig())
assert not manager.enabled assert not manager.enabled
@@ -77,7 +73,7 @@ def test_event_tracking(manager):
with patch("posthog.capture") as mock_capture: with patch("posthog.capture") as mock_capture:
manager.enabled = True manager.enabled = True
manager.track_event("test_event", {"key": "value"}) manager.track_event("test_event", {"key": "value"})
assert mock_capture.called assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event" assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"]["key"] == "value" assert mock_capture.call_args[1]["properties"]["key"] == "value"
@@ -97,28 +93,31 @@ def test_model_tracking(manager):
tokenizer_config={}, tokenizer_config={},
flash_attention=True, flash_attention=True,
quantization_config=None, quantization_config=None,
training_approach="lora" training_approach="lora",
) )
with patch("posthog.capture") as mock_capture: with patch("posthog.capture") as mock_capture:
manager.enabled = True manager.enabled = True
manager.track_model_load(model_config) manager.track_model_load(model_config)
assert mock_capture.called assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "model_load" assert mock_capture.call_args[1]["event"] == "model_load"
assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict() assert (
mock_capture.call_args[1]["properties"]["model_config"]
== model_config.to_dict()
)
def test_training_context(manager): def test_training_context(manager):
"""Test training context manager""" """Test training context manager"""
config = {"model": "llama", "batch_size": 8} config = {"model": "llama", "batch_size": 8}
with patch("posthog.capture") as mock_capture: with patch("posthog.capture") as mock_capture:
manager.enabled = True manager.enabled = True
with manager.track_training(config): with manager.track_training(config):
pass # Simulate successful training pass # Simulate successful training
# Should have captured training_start and training_complete # Should have captured training_start and training_complete
events = [call[1]["event"] for call in mock_capture.call_args_list] events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events assert "training_start" in events
@@ -128,14 +127,14 @@ def test_training_context(manager):
def test_training_error(manager): def test_training_error(manager):
"""Test training context manager with error""" """Test training context manager with error"""
config = {"model": "llama", "batch_size": 8} config = {"model": "llama", "batch_size": 8}
with patch("posthog.capture") as mock_capture: with patch("posthog.capture") as mock_capture:
manager.enabled = True manager.enabled = True
with pytest.raises(ValueError): with pytest.raises(ValueError):
with manager.track_training(config): with manager.track_training(config):
raise ValueError("Test error") raise ValueError("Test error")
# Should have captured training_start and training_error # Should have captured training_start and training_error
events = [call[1]["event"] for call in mock_capture.call_args_list] events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events assert "training_start" in events