updates
This commit is contained in:
@@ -1,5 +1,15 @@
|
|||||||
"""Init for axolotl.telemetry module."""
|
"""Init for axolotl.telemetry module."""
|
||||||
|
|
||||||
from .manager import ModelConfig, TelemetryConfig, TelemetryManager, init_telemetry_manager
|
from .manager import (
|
||||||
|
ModelConfig,
|
||||||
|
TelemetryConfig,
|
||||||
|
TelemetryManager,
|
||||||
|
init_telemetry_manager,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["TelemetryConfig", "TelemetryManager", "ModelConfig", "init_telemetry_manager"]
|
__all__ = [
|
||||||
|
"TelemetryConfig",
|
||||||
|
"TelemetryManager",
|
||||||
|
"ModelConfig",
|
||||||
|
"init_telemetry_manager",
|
||||||
|
]
|
||||||
|
|||||||
@@ -126,8 +126,7 @@ class TelemetryManager:
|
|||||||
|
|
||||||
base_model = base_model.lower()
|
base_model = base_model.lower()
|
||||||
return any(
|
return any(
|
||||||
org.lower() in base_model
|
org.lower() in base_model for org in self.whitelist.get("organizations", [])
|
||||||
for org in self.whitelist.get("organizations", [])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_posthog(self):
|
def _init_posthog(self):
|
||||||
@@ -192,7 +191,7 @@ class TelemetryManager:
|
|||||||
"run_id": self.run_id,
|
"run_id": self.run_id,
|
||||||
"system_info": system_info,
|
"system_info": system_info,
|
||||||
**properties,
|
**properties,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to send telemetry event: {e}")
|
logger.warning(f"Failed to send telemetry event: {e}")
|
||||||
@@ -266,4 +265,4 @@ class TelemetryManager:
|
|||||||
|
|
||||||
def init_telemetry_manager() -> TelemetryManager:
|
def init_telemetry_manager() -> TelemetryManager:
|
||||||
"""Initialize telemetry system"""
|
"""Initialize telemetry system"""
|
||||||
return TelemetryManager(TelemetryConfig())
|
return TelemetryManager(TelemetryConfig())
|
||||||
|
|||||||
@@ -8,4 +8,4 @@ organizations:
|
|||||||
- "NousResearch"
|
- "NousResearch"
|
||||||
- "allenai"
|
- "allenai"
|
||||||
- "amd"
|
- "amd"
|
||||||
- "tiiuae"
|
- "tiiuae"
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig
|
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -12,7 +12,7 @@ def mock_whitelist(tmp_path):
|
|||||||
"""Create a temporary whitelist file for testing"""
|
"""Create a temporary whitelist file for testing"""
|
||||||
whitelist_content = {
|
whitelist_content = {
|
||||||
"organizations": ["meta", "mistral"],
|
"organizations": ["meta", "mistral"],
|
||||||
"models": ["llama", "mistral-7b"]
|
"models": ["llama", "mistral-7b"],
|
||||||
}
|
}
|
||||||
whitelist_file = tmp_path / "whitelist.yaml"
|
whitelist_file = tmp_path / "whitelist.yaml"
|
||||||
with open(whitelist_file, "w") as f:
|
with open(whitelist_file, "w") as f:
|
||||||
@@ -24,8 +24,7 @@ def mock_whitelist(tmp_path):
|
|||||||
def config(mock_whitelist):
|
def config(mock_whitelist):
|
||||||
"""Create a TelemetryConfig with test settings"""
|
"""Create a TelemetryConfig with test settings"""
|
||||||
return TelemetryConfig(
|
return TelemetryConfig(
|
||||||
host="https://test.posthog.com",
|
host="https://test.posthog.com", whitelist_path=mock_whitelist
|
||||||
whitelist_path=mock_whitelist
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -51,10 +50,7 @@ def test_telemetry_opt_in():
|
|||||||
|
|
||||||
def test_do_not_track_override():
|
def test_do_not_track_override():
|
||||||
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
||||||
with patch.dict(os.environ, {
|
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
|
||||||
"AXOLOTL_TELEMETRY": "1",
|
|
||||||
"DO_NOT_TRACK": "1"
|
|
||||||
}):
|
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager(TelemetryConfig())
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
@@ -77,7 +73,7 @@ def test_event_tracking(manager):
|
|||||||
with patch("posthog.capture") as mock_capture:
|
with patch("posthog.capture") as mock_capture:
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
manager.track_event("test_event", {"key": "value"})
|
manager.track_event("test_event", {"key": "value"})
|
||||||
|
|
||||||
assert mock_capture.called
|
assert mock_capture.called
|
||||||
assert mock_capture.call_args[1]["event"] == "test_event"
|
assert mock_capture.call_args[1]["event"] == "test_event"
|
||||||
assert mock_capture.call_args[1]["properties"]["key"] == "value"
|
assert mock_capture.call_args[1]["properties"]["key"] == "value"
|
||||||
@@ -97,28 +93,31 @@ def test_model_tracking(manager):
|
|||||||
tokenizer_config={},
|
tokenizer_config={},
|
||||||
flash_attention=True,
|
flash_attention=True,
|
||||||
quantization_config=None,
|
quantization_config=None,
|
||||||
training_approach="lora"
|
training_approach="lora",
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
with patch("posthog.capture") as mock_capture:
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
manager.track_model_load(model_config)
|
manager.track_model_load(model_config)
|
||||||
|
|
||||||
assert mock_capture.called
|
assert mock_capture.called
|
||||||
assert mock_capture.call_args[1]["event"] == "model_load"
|
assert mock_capture.call_args[1]["event"] == "model_load"
|
||||||
assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict()
|
assert (
|
||||||
|
mock_capture.call_args[1]["properties"]["model_config"]
|
||||||
|
== model_config.to_dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_training_context(manager):
|
def test_training_context(manager):
|
||||||
"""Test training context manager"""
|
"""Test training context manager"""
|
||||||
config = {"model": "llama", "batch_size": 8}
|
config = {"model": "llama", "batch_size": 8}
|
||||||
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
with patch("posthog.capture") as mock_capture:
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
|
|
||||||
with manager.track_training(config):
|
with manager.track_training(config):
|
||||||
pass # Simulate successful training
|
pass # Simulate successful training
|
||||||
|
|
||||||
# Should have captured training_start and training_complete
|
# Should have captured training_start and training_complete
|
||||||
events = [call[1]["event"] for call in mock_capture.call_args_list]
|
events = [call[1]["event"] for call in mock_capture.call_args_list]
|
||||||
assert "training_start" in events
|
assert "training_start" in events
|
||||||
@@ -128,14 +127,14 @@ def test_training_context(manager):
|
|||||||
def test_training_error(manager):
|
def test_training_error(manager):
|
||||||
"""Test training context manager with error"""
|
"""Test training context manager with error"""
|
||||||
config = {"model": "llama", "batch_size": 8}
|
config = {"model": "llama", "batch_size": 8}
|
||||||
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
with patch("posthog.capture") as mock_capture:
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
with manager.track_training(config):
|
with manager.track_training(config):
|
||||||
raise ValueError("Test error")
|
raise ValueError("Test error")
|
||||||
|
|
||||||
# Should have captured training_start and training_error
|
# Should have captured training_start and training_error
|
||||||
events = [call[1]["event"] for call in mock_capture.call_args_list]
|
events = [call[1]["event"] for call in mock_capture.call_args_list]
|
||||||
assert "training_start" in events
|
assert "training_start" in events
|
||||||
|
|||||||
Reference in New Issue
Block a user