diff --git a/src/axolotl/telemetry/__init__.py b/src/axolotl/telemetry/__init__.py index 99edb167c..3aae57c11 100644 --- a/src/axolotl/telemetry/__init__.py +++ b/src/axolotl/telemetry/__init__.py @@ -1,5 +1,15 @@ """Init for axolotl.telemetry module.""" -from .manager import ModelConfig, TelemetryConfig, TelemetryManager, init_telemetry_manager +from .manager import ( + ModelConfig, + TelemetryConfig, + TelemetryManager, + init_telemetry_manager, +) -__all__ = ["TelemetryConfig", "TelemetryManager", "ModelConfig", "init_telemetry_manager"] +__all__ = [ + "TelemetryConfig", + "TelemetryManager", + "ModelConfig", + "init_telemetry_manager", +] diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py index ae5f5e2aa..6b132713d 100644 --- a/src/axolotl/telemetry/manager.py +++ b/src/axolotl/telemetry/manager.py @@ -126,8 +126,7 @@ class TelemetryManager: base_model = base_model.lower() return any( - org.lower() in base_model - for org in self.whitelist.get("organizations", []) + org.lower() in base_model for org in self.whitelist.get("organizations", []) ) def _init_posthog(self): @@ -192,7 +191,7 @@ class TelemetryManager: "run_id": self.run_id, "system_info": system_info, **properties, - } + }, ) except Exception as e: logger.warning(f"Failed to send telemetry event: {e}") @@ -266,4 +265,4 @@ class TelemetryManager: def init_telemetry_manager() -> TelemetryManager: """Initialize telemetry system""" - return TelemetryManager(TelemetryConfig()) \ No newline at end of file + return TelemetryManager(TelemetryConfig()) diff --git a/src/axolotl/telemetry/whitelist.yaml b/src/axolotl/telemetry/whitelist.yaml index a0e3a5562..f7b5afecc 100644 --- a/src/axolotl/telemetry/whitelist.yaml +++ b/src/axolotl/telemetry/whitelist.yaml @@ -8,4 +8,4 @@ organizations: - "NousResearch" - "allenai" - "amd" - - "tiiuae" \ No newline at end of file + - "tiiuae" diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py index 43305f278..f042b3845 100644 --- a/tests/telemetry/test_manager.py +++ b/tests/telemetry/test_manager.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest import yaml -from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig +from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager @pytest.fixture @@ -12,7 +12,7 @@ def mock_whitelist(tmp_path): """Create a temporary whitelist file for testing""" whitelist_content = { "organizations": ["meta", "mistral"], - "models": ["llama", "mistral-7b"] + "models": ["llama", "mistral-7b"], } whitelist_file = tmp_path / "whitelist.yaml" with open(whitelist_file, "w") as f: @@ -24,8 +24,7 @@ def mock_whitelist(tmp_path): def config(mock_whitelist): """Create a TelemetryConfig with test settings""" return TelemetryConfig( - host="https://test.posthog.com", - whitelist_path=mock_whitelist + host="https://test.posthog.com", whitelist_path=mock_whitelist ) @@ -51,10 +50,7 @@ def test_telemetry_opt_in(): def test_do_not_track_override(): """Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY""" - with patch.dict(os.environ, { - "AXOLOTL_TELEMETRY": "1", - "DO_NOT_TRACK": "1" - }): + with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}): manager = TelemetryManager(TelemetryConfig()) assert not manager.enabled @@ -77,7 +73,7 @@ def test_event_tracking(manager): with patch("posthog.capture") as mock_capture: manager.enabled = True manager.track_event("test_event", {"key": "value"}) - + assert mock_capture.called assert mock_capture.call_args[1]["event"] == "test_event" assert mock_capture.call_args[1]["properties"]["key"] == "value" @@ -97,28 +93,31 @@ def test_model_tracking(manager): tokenizer_config={}, flash_attention=True, quantization_config=None, - training_approach="lora" + training_approach="lora", ) - + with patch("posthog.capture") as mock_capture: manager.enabled = True manager.track_model_load(model_config) - + assert mock_capture.called assert mock_capture.call_args[1]["event"] == "model_load" - assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict() + assert ( + mock_capture.call_args[1]["properties"]["model_config"] + == model_config.to_dict() + ) def test_training_context(manager): """Test training context manager""" config = {"model": "llama", "batch_size": 8} - + with patch("posthog.capture") as mock_capture: manager.enabled = True - + with manager.track_training(config): pass # Simulate successful training - + # Should have captured training_start and training_complete events = [call[1]["event"] for call in mock_capture.call_args_list] assert "training_start" in events @@ -128,14 +127,14 @@ def test_training_context(manager): def test_training_error(manager): """Test training context manager with error""" config = {"model": "llama", "batch_size": 8} - + with patch("posthog.capture") as mock_capture: manager.enabled = True - + with pytest.raises(ValueError): with manager.track_training(config): raise ValueError("Test error") - + # Should have captured training_start and training_error events = [call[1]["event"] for call in mock_capture.call_args_list] assert "training_start" in events