updates
This commit is contained in:
@@ -1,206 +0,0 @@
|
|||||||
"""Telemetry manager and associated utilities."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import threading
|
|
||||||
import uuid
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from queue import Empty, Full, Queue
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import posthog
|
|
||||||
import psutil
|
|
||||||
import torch
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class TelemetryConfig:
    """Settings consumed by the telemetry manager.

    ``enabled`` and ``project_api_key`` must be supplied; every other field
    has a default.
    """

    # Master switch checked first by TelemetryManager._check_telemetry_enabled.
    enabled: bool
    # Value assigned to posthog.project_api_key.
    project_api_key: str
    # Value assigned to posthog.host.
    host: str = "https://app.posthog.com"  # TODO: replace with self-hosted endpoint?
    # Capacity of the manager's event queue.
    queue_size: int = 100
    # Maximum number of events the worker drains into one batch.
    batch_size: int = 10
    # NOTE(review): not read by the manager in this file - presumably reserved
    # for a whitelist loader; confirm.
    whitelist_path: str = "telemetry_whitelist.yaml"
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
    """Manages telemetry collection and transmission.

    ``track_event`` places events on a bounded queue; a daemon worker thread
    drains the queue in batches and forwards each batch to PostHog.
    """

    def __init__(self, config: TelemetryConfig):
        """
        Telemetry manager constructor.

        Args:
            config: Telemetry configuration object.
        """
        self.config = config
        self.enabled = self._check_telemetry_enabled()
        self.run_id = str(uuid.uuid4())
        self.event_queue: Queue = Queue(maxsize=config.queue_size)

        # The PostHog client and the worker thread are only set up when
        # telemetry is enabled.
        if self.enabled:
            self._init_posthog()
            self._start_worker()

    def _check_telemetry_enabled(self) -> bool:
        """Check if telemetry is enabled based on config and environment variables.

        Returns:
            False when the config disables telemetry, or when ``DO_NOT_TRACK``
            or ``AXOLOTL_DO_NOT_TRACK`` is "1"/"true" (case-insensitive);
            True otherwise.
        """
        if not self.config.enabled:
            return False

        truthy = ("1", "true")
        do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in truthy
        axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK", "0").lower() in truthy

        return not (do_not_track or axolotl_do_not_track)

    def _init_posthog(self):
        """Initialize PostHog client from the configured key and host."""
        posthog.project_api_key = self.config.project_api_key
        posthog.host = self.config.host

    def _start_worker(self):
        """Start background worker thread for processing events."""
        # daemon=True so the thread never blocks interpreter exit.
        self.worker_thread = threading.Thread(target=self._process_queue, daemon=True)
        self.worker_thread.start()

    def _process_queue(self):
        """Process events from queue and send to PostHog.

        Runs forever on the worker thread. Each iteration blocks for one
        event, drains up to ``batch_size - 1`` more without blocking, and
        sends the batch. A failed send is logged and the batch is dropped.
        """
        while True:
            # Always get at least one event (blocking)
            events = [self.event_queue.get()]

            # Try to get more events up to batch size (non-blocking)
            for _ in range(self.config.batch_size - 1):
                try:
                    events.append(self.event_queue.get_nowait())
                except Empty:
                    # No more events available right now
                    break

            try:
                # NOTE(review): confirm the installed posthog client exposes
                # `capture_batch`.
                posthog.capture_batch(events)
            except ConnectionError as e:
                logger.warning(f"Network error while sending telemetry: {e}")
            except Exception as e:
                # Deliberately broad: an uncaught exception here would end the
                # worker thread, after which the queue fills up and every
                # later event is dropped.
                logger.warning(f"Failed to send telemetry batch: {e}")
            finally:
                # Mark tasks as done even if sending failed
                for _ in events:
                    self.event_queue.task_done()

    def _sanitize_path(self, path: str) -> str:
        """Remove personal information from file paths.

        Returns:
            Only the final component of ``path``.
        """
        return Path(path).name

    def _sanitize_error(self, error: str) -> str:
        """Remove personal information from error messages.

        Every whitespace/quote-delimited token that contains a path separator
        (``/`` or ``\\``) is reduced to the text after its last separator, so
        directory names are dropped while the rest of the message is kept
        intact. Tokens such as URLs are shortened the same way.

        Args:
            error: The error message to clean.

        Returns:
            The message with directory components removed.
        """
        import re  # local import: only this method needs it

        def _last_component(match):
            token = match.group(0)
            cut = max(token.rfind("/"), token.rfind("\\"))
            return token[cut + 1:]

        return re.sub(r"[^\s'\"]*[\\/][^\s'\"]*", _last_component, error)

    def _get_system_info(self) -> dict[str, Any]:
        """Collect system information.

        Returns:
            A dict with OS name, Python version, CPU count, total memory, and
            the name and total memory of each CUDA device (empty list when
            CUDA is unavailable).
        """
        gpu_info = []
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_info.append(
                    {
                        "name": torch.cuda.get_device_name(i),
                        "memory": torch.cuda.get_device_properties(i).total_memory,
                    }
                )

        return {
            "os": platform.system(),
            "python_version": platform.python_version(),
            "cpu_count": psutil.cpu_count(),
            "memory_total": psutil.virtual_memory().total,
            "gpu_count": len(gpu_info),
            "gpu_info": gpu_info,
        }

    def track_event(self, event_type: str, properties: dict[str, Any]):
        """Track a telemetry event.

        Does nothing when telemetry is disabled. Otherwise the event is
        queued for the worker; if the queue is full the event is dropped with
        a warning. Collection errors are logged, never raised.

        Args:
            event_type: Event name.
            properties: Extra event properties; these are merged last, so
                they override ``run_id`` / ``system_info`` on key collision.
        """
        if not self.enabled:
            return

        try:
            # Get system info first - most likely source of errors
            system_info = self._get_system_info()

            # Construct event dict - could raise TypeError if properties aren't serializable
            event = {
                "event": event_type,
                "properties": {
                    "run_id": self.run_id,
                    "system_info": system_info,
                    **properties,
                },
            }

            try:
                self.event_queue.put_nowait(event)
            except Full:
                logger.warning("Telemetry queue full, dropping event")
        except (RuntimeError, OSError) as e:
            # Hardware info collection errors
            logger.warning(f"Failed to collect system info for telemetry: {e}")
        except TypeError as e:
            # Dict construction/serialization errors
            logger.warning(f"Invalid property type in telemetry event: {e}")
        except AttributeError as e:
            # Missing attributes when collecting system info
            logger.warning(f"Failed to access system attribute for telemetry: {e}")

    @contextmanager
    def track_training(self, config: dict[str, Any]):
        """Context manager to track training run.

        Emits ``training_start`` on entry, then ``training_complete`` on a
        clean exit or ``training_error`` (with a sanitized message) before
        re-raising. When telemetry is disabled it only yields.

        Args:
            config: Training config; top-level keys containing "path", "dir"
                or "file" are omitted from the reported copy.
        """
        if not self.enabled:
            yield
            return

        # Track training start
        sanitized_config = {
            k: v
            for k, v in config.items()
            if not any(p in k.lower() for p in ["path", "dir", "file"])
        }

        self.track_event("training_start", {"config": sanitized_config})

        try:
            yield
            # Track successful completion
            self.track_event("training_complete", {})

        except Exception as e:
            # Track error
            self.track_event("training_error", {"error": self._sanitize_error(str(e))})
            raise
|
|
||||||
|
|
||||||
|
|
||||||
def init_telemetry(project_api_key: str, enabled: bool = True) -> TelemetryManager:
    """Build a ``TelemetryManager`` from an API key and an enabled flag.

    Args:
        project_api_key: Key stored on the generated ``TelemetryConfig``.
        enabled: Value for the config's ``enabled`` field.

    Returns:
        A manager constructed from a config that uses defaults for all other
        fields.
    """
    return TelemetryManager(
        TelemetryConfig(enabled=enabled, project_api_key=project_api_key)
    )
|
|
||||||
5
src/axolotl/telemetry/__init__.py
Normal file
5
src/axolotl/telemetry/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Init for axolotl.telemetry module."""
|
||||||
|
|
||||||
|
from .manager import (
    ModelConfig,
    TelemetryConfig,
    TelemetryManager,
    init_telemetry_manager,
)

# Names re-exported as the package's public API.
__all__ = [
    "TelemetryConfig",
    "TelemetryManager",
    "ModelConfig",
    "init_telemetry_manager",
]
|
||||||
269
src/axolotl/telemetry/manager.py
Normal file
269
src/axolotl/telemetry/manager.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""Telemetry manager and associated utilities."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import uuid
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import posthog
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Project key assigned to ``posthog.project_api_key`` when the manager
# initializes the PostHog client.
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
"""Tracked model configuration details"""
|
||||||
|
|
||||||
|
base_model: str
|
||||||
|
model_type: str
|
||||||
|
hidden_size: int
|
||||||
|
num_layers: int
|
||||||
|
num_attention_heads: int
|
||||||
|
tokenizer_config: dict
|
||||||
|
flash_attention: bool
|
||||||
|
quantization_config: dict | None
|
||||||
|
training_approach: str # 'lora', 'qlora', 'full_finetune'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: dict) -> "ModelConfig":
|
||||||
|
"""Create from Axolotl config dict"""
|
||||||
|
return cls(
|
||||||
|
base_model=config.get("base_model", ""),
|
||||||
|
model_type=config.get("model_type", ""),
|
||||||
|
hidden_size=config.get("hidden_size", 0),
|
||||||
|
num_layers=config.get("num_layers", 0),
|
||||||
|
num_attention_heads=config.get("num_attention_heads", 0),
|
||||||
|
tokenizer_config=config.get("tokenizer", {}),
|
||||||
|
flash_attention=config.get("flash_attention", False),
|
||||||
|
quantization_config=config.get("quantization", None),
|
||||||
|
training_approach=config.get("training_approach", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert to PostHog-compatible dict"""
|
||||||
|
return {
|
||||||
|
"base_model": self.base_model,
|
||||||
|
"model_type": self.model_type,
|
||||||
|
"architecture": {
|
||||||
|
"hidden_size": self.hidden_size,
|
||||||
|
"num_layers": self.num_layers,
|
||||||
|
"num_attention_heads": self.num_attention_heads,
|
||||||
|
},
|
||||||
|
"optimizations": {
|
||||||
|
"flash_attention": self.flash_attention,
|
||||||
|
"quantization": self.quantization_config is not None,
|
||||||
|
"quantization_config": self.quantization_config,
|
||||||
|
},
|
||||||
|
"training_approach": self.training_approach,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TelemetryConfig:
    """Settings for the telemetry manager; every field has a default."""

    # Value assigned to posthog.host.
    host: str = "https://app.posthog.com"
    queue_size: int = 100
    batch_size: int = 10
    # Defaults to the whitelist.yaml that sits next to this module.
    whitelist_path: str = str(Path(__file__).with_name("whitelist.yaml"))
    retention_days: int = 365
    # NOTE(review): this default is evaluated once, when the class is defined,
    # so every config created in the same process shares the same id unless
    # one is passed explicitly - confirm that is intended.
    distinct_id: str = str(uuid.uuid4())
    schema_version: str = "0.1.0"
|
||||||
|
|
||||||
|
|
||||||
|
class TelemetryManager:
    """Manages telemetry collection and transmission.

    Telemetry is opt-in (see ``_check_telemetry_enabled``). When enabled,
    events are sent to PostHog through ``posthog.capture``.
    """

    def __init__(self, config: TelemetryConfig):
        """
        Telemetry manager constructor.

        Args:
            config: Telemetry configuration object.
        """
        self.config = config
        self.run_id = str(uuid.uuid4())
        self.enabled = self._check_telemetry_enabled()
        # Always define the attribute so `_is_whitelisted` is safe to call
        # when telemetry is disabled or `enabled` is toggled after
        # construction; an empty whitelist matches nothing.
        self.whitelist: dict = {}

        if self.enabled:
            self.whitelist = self._load_whitelist()
            self._init_posthog()

    def _check_telemetry_enabled(self) -> bool:
        """
        Check if telemetry is enabled based on environment variables.

        Note: This is enabled on an opt-in basis. Please consider setting
        `AXOLOTL_TELEMETRY=1` to send us valuable data on which models and algos you're
        using so we can focus our engineering efforts!
        """
        # Only enable if explicitly opted in
        axolotl_telemetry = os.getenv("AXOLOTL_TELEMETRY", "0").lower() in ("1", "true")

        # Respect DO_NOT_TRACK as an override even if telemetry is enabled
        do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true")

        return axolotl_telemetry and not do_not_track

    def _load_whitelist(self) -> dict:
        """Load organization/model whitelist.

        Returns:
            The mapping parsed from ``config.whitelist_path``, or an empty
            dict when the file is empty or its top-level value is not a
            mapping.
        """
        with open(self.config.whitelist_path, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)

        # safe_load yields None for an empty document; normalize so callers
        # can always use dict methods.
        return loaded if isinstance(loaded, dict) else {}

    def _is_whitelisted(self, base_model: str) -> bool:
        """Check if model/org is in whitelist.

        The match is a case-insensitive substring test of each whitelisted
        organization against ``base_model``.

        Args:
            base_model: Model identifier; empty values are never whitelisted.
        """
        if not base_model:
            return False

        # `or []` also covers an `organizations:` key with no entries (None).
        organizations = self.whitelist.get("organizations") or []
        base_model = base_model.lower()
        return any(str(org).lower() in base_model for org in organizations)

    def _init_posthog(self):
        """Initialize PostHog client"""
        posthog.project_api_key = POSTHOG_WRITE_KEY
        posthog.host = self.config.host

    def _sanitize_path(self, path: str) -> str:
        """Remove personal information from file paths.

        Returns:
            Only the final component of ``path``.
        """
        return Path(path).name

    def _sanitize_error(self, error: str) -> str:
        """Remove personal information from error messages.

        Every whitespace/quote-delimited token that contains a path separator
        (``/`` or ``\\``) is reduced to the text after its last separator, so
        directory names are dropped while the rest of the message is kept
        intact. Tokens such as URLs are shortened the same way.

        Args:
            error: The error message to clean.

        Returns:
            The message with directory components removed.
        """
        import re  # local import: only this method needs it

        def _last_component(match):
            token = match.group(0)
            cut = max(token.rfind("/"), token.rfind("\\"))
            return token[cut + 1:]

        return re.sub(r"[^\s'\"]*[\\/][^\s'\"]*", _last_component, error)

    def _get_system_info(self) -> dict[str, Any]:
        """Collect system information.

        Returns:
            A dict with OS name, Python version, CPU count, total memory, and
            the name and total memory of each CUDA device (empty list when
            CUDA is unavailable).
        """
        gpu_info = []
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_info.append(
                    {
                        "name": torch.cuda.get_device_name(i),
                        "memory": torch.cuda.get_device_properties(i).total_memory,
                    }
                )

        return {
            "os": platform.system(),
            "python_version": platform.python_version(),
            "cpu_count": psutil.cpu_count(),
            "memory_total": psutil.virtual_memory().total,
            "gpu_count": len(gpu_info),
            "gpu_info": gpu_info,
        }

    def track_event(self, event_type: str, properties: dict[str, Any]):
        """Track a telemetry event.

        Does nothing when telemetry is disabled. Failures are logged as
        warnings, never raised.

        Args:
            event_type: Event name.
            properties: Extra event properties; these are merged last, so
                they override ``run_id`` / ``system_info`` on key collision.
        """
        if not self.enabled:
            return

        try:
            # Get system info first - most likely source of errors
            system_info = self._get_system_info()

            # Send event via PostHog
            try:
                posthog.capture(
                    distinct_id=self.config.distinct_id,
                    event=event_type,
                    properties={
                        "run_id": self.run_id,
                        "system_info": system_info,
                        **properties,
                    },
                )
            except Exception as e:
                logger.warning(f"Failed to send telemetry event: {e}")
        except (RuntimeError, OSError) as e:
            logger.warning(f"Failed to collect system info for telemetry: {e}")
        except TypeError as e:
            logger.warning(f"Invalid property type in telemetry event: {e}")
        except AttributeError as e:
            logger.warning(f"Failed to access system attribute for telemetry: {e}")

    @contextmanager
    def track_training(self, config: dict[str, Any]):
        """Context manager to track training run.

        Emits ``training_start`` on entry, then ``training_complete`` on a
        clean exit or ``training_error`` (with a sanitized message) before
        re-raising. When telemetry is disabled it only yields.

        Args:
            config: Training config; top-level keys containing "path", "dir"
                or "file" are omitted from the reported copy.
        """
        if not self.enabled:
            yield
            return

        # Track training start
        sanitized_config = {
            k: v
            for k, v in config.items()
            if not any(p in k.lower() for p in ["path", "dir", "file"])
        }

        self.track_event("training_start", {"config": sanitized_config})

        try:
            yield
            # Track successful completion
            self.track_event("training_complete", {})

        except Exception as e:
            # Track error
            self.track_event("training_error", {"error": self._sanitize_error(str(e))})
            raise

    def track_model_load(self, model_config: ModelConfig):
        """Track model loading and configuration.

        Skipped when telemetry is disabled or the base model is not
        whitelisted.
        """
        if not self.enabled or not self._is_whitelisted(model_config.base_model):
            return

        # `track_event` attaches system_info itself, inside its own error
        # handling; collecting it here as well would repeat the hardware
        # probe outside that protection.
        self.track_event("model_load", {"model_config": model_config.to_dict()})

    def track_training_metrics(self, metrics: dict):
        """Track training progress metrics.

        Only the known metric keys are forwarded; keys missing from
        ``metrics`` are reported as ``None``.
        """
        if not self.enabled:
            return

        self.track_event(
            "training_metrics",
            {
                "duration": metrics.get("duration"),
                "peak_memory": metrics.get("peak_memory"),
                "steps_completed": metrics.get("steps_completed"),
                "current_loss": metrics.get("loss"),
                "learning_rate": metrics.get("learning_rate"),
            },
        )

    def shutdown(self):
        """Ensure all queued events are processed before shutdown"""
        if self.enabled:
            posthog.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def init_telemetry_manager() -> TelemetryManager:
    """Create a ``TelemetryManager`` backed by a default ``TelemetryConfig``."""
    default_config = TelemetryConfig()
    return TelemetryManager(default_config)
|
||||||
11
src/axolotl/telemetry/whitelist.yaml
Normal file
11
src/axolotl/telemetry/whitelist.yaml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
organizations:
|
||||||
|
- "meta-llama"
|
||||||
|
- "huggingface"
|
||||||
|
- "nvidia"
|
||||||
|
- "facebook"
|
||||||
|
- "briaai"
|
||||||
|
- "unsloth"
|
||||||
|
- "NousResearch"
|
||||||
|
- "allenai"
|
||||||
|
- "amd"
|
||||||
|
- "tiiuae"
|
||||||
166
tests/telemetry/test_manager.py
Normal file
166
tests/telemetry/test_manager.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_whitelist(tmp_path):
    """Write a small whitelist YAML under ``tmp_path`` and return its path as a string."""
    target = tmp_path / "whitelist.yaml"
    payload = {
        "organizations": ["meta", "mistral"],
        "models": ["llama", "mistral-7b"],
    }
    with target.open("w") as handle:
        yaml.dump(payload, handle)
    return str(target)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config(mock_whitelist):
    """Build a TelemetryConfig that points at the test host and the temporary whitelist."""
    settings = {
        "host": "https://test.posthog.com",
        "whitelist_path": mock_whitelist,
    }
    return TelemetryConfig(**settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def manager(config):
    """Create an enabled TelemetryManager instance with mocked PostHog.

    Telemetry is opt-in, so the opt-in variable is set while the manager is
    constructed. Without it the constructor skips loading the whitelist, and
    tests that consult the whitelist fail with AttributeError.
    """
    opt_in_env = {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "0"}
    with patch.dict(os.environ, opt_in_env), patch("posthog.capture"):
        return TelemetryManager(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_telemetry_disabled_by_default():
    """Test that telemetry is disabled by default"""
    # Start from an empty environment so a value exported in the developer's
    # shell or CI (e.g. AXOLOTL_TELEMETRY=1) cannot change the outcome.
    with patch.dict(os.environ, {}, clear=True):
        manager = TelemetryManager(TelemetryConfig())
    assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
|
def test_telemetry_opt_in():
    """Test that telemetry can be enabled via environment variable"""
    # DO_NOT_TRACK is pinned off as well: an ambient DO_NOT_TRACK=1 would
    # override the opt-in and make this test fail for unrelated reasons.
    opt_in_env = {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "0"}
    with patch.dict(os.environ, opt_in_env):
        manager = TelemetryManager(TelemetryConfig())
        assert manager.enabled
|
||||||
|
|
||||||
|
|
||||||
|
def test_do_not_track_override():
    """DO_NOT_TRACK must win even when AXOLOTL_TELEMETRY opts in."""
    conflicting_env = {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}
    with patch.dict(os.environ, conflicting_env):
        manager = TelemetryManager(TelemetryConfig())
        assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
|
def test_whitelist_checking(manager):
    """Test organization whitelist functionality"""
    # Load the fixture whitelist explicitly: patching the environment here
    # would have no effect because the manager is already constructed, and a
    # manager built with telemetry disabled never loads its whitelist.
    manager.whitelist = manager._load_whitelist()

    # Should match organization
    assert manager._is_whitelisted("meta/llama-7b")
    # Organization names match as substrings, so a bare model name that
    # contains one matches too
    assert manager._is_whitelisted("mistral-7b-instruct")
    # Should not match
    assert not manager._is_whitelisted("unknown/model")
    # Should handle case insensitively
    assert manager._is_whitelisted("meta/Llama-7b")
    assert manager._is_whitelisted("Meta/Llama-7b")
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_tracking(manager):
    """A tracked event reaches posthog.capture with the expected payload."""
    with patch("posthog.capture") as mock_capture:
        manager.enabled = True
        manager.track_event("test_event", {"key": "value"})

        assert mock_capture.called
        sent = mock_capture.call_args[1]
        assert sent["event"] == "test_event"
        sent_properties = sent["properties"]
        assert sent_properties["key"] == "value"
        assert "run_id" in sent_properties
        assert "system_info" in sent_properties
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_tracking(manager):
    """Test model load tracking"""
    # Load the fixture whitelist explicitly: patching the environment here
    # would have no effect because the manager is already constructed, and a
    # manager built with telemetry disabled never loads its whitelist.
    manager.whitelist = manager._load_whitelist()

    model_config = ModelConfig(
        base_model="meta/llama-7b",
        model_type="decoder",
        hidden_size=4096,
        num_layers=32,
        num_attention_heads=32,
        tokenizer_config={},
        flash_attention=True,
        quantization_config=None,
        training_approach="lora",
    )

    with patch("posthog.capture") as mock_capture:
        manager.enabled = True
        manager.track_model_load(model_config)

        assert mock_capture.called
        assert mock_capture.call_args[1]["event"] == "model_load"
        assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_context(manager):
    """A clean run through track_training emits start and complete events."""
    training_config = {"model": "llama", "batch_size": 8}

    with patch("posthog.capture") as mock_capture:
        manager.enabled = True

        with manager.track_training(training_config):
            pass  # Simulate successful training

        # Collect the event name of every capture call
        emitted = []
        for recorded_call in mock_capture.call_args_list:
            emitted.append(recorded_call[1]["event"])
        assert "training_start" in emitted
        assert "training_complete" in emitted
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_error(manager):
    """An exception inside track_training is re-raised and reported as training_error."""
    training_config = {"model": "llama", "batch_size": 8}

    with patch("posthog.capture") as mock_capture:
        manager.enabled = True

        with pytest.raises(ValueError):
            with manager.track_training(training_config):
                raise ValueError("Test error")

        # Collect the event name of every capture call
        emitted = []
        for recorded_call in mock_capture.call_args_list:
            emitted.append(recorded_call[1]["event"])
        assert "training_start" in emitted
        assert "training_error" in emitted
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_sanitization(manager):
    """_sanitize_path keeps only the file name."""
    cleaned = manager._sanitize_path("/home/user/sensitive/data.txt")
    assert cleaned == "data.txt"
    assert "/home/user" not in cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_sanitization(manager):
    """_sanitize_error strips directory names from an error message."""
    message = "Failed to load /home/user/sensitive/data.txt: File not found"
    cleaned = manager._sanitize_error(message)
    for leaked in ("sensitive", "/home/user"):
        assert leaked not in cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def test_shutdown(manager):
    """shutdown() flushes the PostHog client when telemetry is enabled."""
    with patch("posthog.flush") as mock_flush:
        manager.enabled = True
        manager.shutdown()
        flushed = mock_flush.called
        assert flushed
|
||||||
Reference in New Issue
Block a user