progress on telemetry: config load, process, model load, train start / end, error tracking

This commit is contained in:
Dan Saunders
2025-02-19 22:05:12 +00:00
parent 90b39ce112
commit 66c6fb56cb
12 changed files with 227 additions and 217 deletions

6
docs/telemetry.qmd Normal file
View File

@@ -0,0 +1,6 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---
TODO.

View File

@@ -14,6 +14,8 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry import TelemetryManager
from axolotl.telemetry.manager import track_errors
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -28,6 +30,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__, use_environ=True)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -159,6 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg
@track_errors
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
@@ -192,6 +197,8 @@ def load_cfg(
temp_file.close()
cfg.axolotl_config_path = temp_file.name
TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
@@ -233,4 +240,6 @@ def load_cfg(
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
return cfg

View File

@@ -162,6 +162,7 @@ def load_lora(
return model, lora_config
@send_errors
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,

View File

@@ -145,6 +145,7 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.cfg.fsdp and self.cfg.adapter == "qlora"
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.

View File

@@ -14,6 +14,7 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: dict[str, Any] = {} # Do we actually need this?

View File

@@ -117,6 +117,7 @@ def modify_tokenizer_files(
return tokenizer_dir
@send_errors
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)

View File

@@ -1,15 +1,8 @@
"""Init for axolotl.telemetry module."""
from .manager import (
ModelConfig,
TelemetryConfig,
TelemetryManager,
init_telemetry_manager,
)
from .manager import TelemetryConfig, TelemetryManager
__all__ = [
"TelemetryConfig",
"TelemetryManager",
"ModelConfig",
"init_telemetry_manager",
]

View File

@@ -1,70 +1,46 @@
"""Telemetry manager and associated utilities."""
import atexit
import logging
import os
import platform
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import wraps
from inspect import getmodule
from pathlib import Path
from typing import Any
from typing import Any, Callable
import posthog
import psutil
import torch
import transformers
import yaml
logger = logging.getLogger(__name__)
import axolotl
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger(__name__)
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
@dataclass
class ModelConfig:
"""Tracked model configuration details"""
base_model: str
model_type: str
hidden_size: int
num_layers: int
num_attention_heads: int
tokenizer_config: dict
flash_attention: bool
quantization_config: dict | None
training_approach: str # 'lora', 'qlora', 'full_finetune'
@classmethod
def from_config(cls, config: dict) -> "ModelConfig":
"""Create from Axolotl config dict"""
return cls(
base_model=config.get("base_model", ""),
model_type=config.get("model_type", ""),
hidden_size=config.get("hidden_size", 0),
num_layers=config.get("num_layers", 0),
num_attention_heads=config.get("num_attention_heads", 0),
tokenizer_config=config.get("tokenizer", {}),
flash_attention=config.get("flash_attention", False),
quantization_config=config.get("quantization", None),
training_approach=config.get("training_approach", ""),
)
def to_dict(self) -> dict:
"""Convert to PostHog-compatible dict"""
return {
"base_model": self.base_model,
"model_type": self.model_type,
"architecture": {
"hidden_size": self.hidden_size,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
},
"optimizations": {
"flash_attention": self.flash_attention,
"quantization": self.quantization_config is not None,
"quantization_config": self.quantization_config,
},
"training_approach": self.training_approach,
}
ENABLED_WARNING_SLEEP_SECONDS = 10
ENABLED_WARNING = (
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
"- Which models and configurations are most commonly used\n"
"- What hardware setups need to be supported\n"
"- Where users encounter errors\n\n"
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
"To disable telemetry, set either:\n"
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
"- DO_NOT_TRACK=1 (Global standard)\n\n"
"To remove this warning and continue with telemetry enabled,"
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
"No personally identifiable information is collected."
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
)
@dataclass
@@ -76,43 +52,95 @@ class TelemetryConfig:
batch_size: int = 10
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
retention_days: int = 365
distinct_id: str = str(uuid.uuid4())
schema_version: str = "0.1.0"
class TelemetryManager:
"""Manages telemetry collection and transmission"""
def __init__(self, config: TelemetryConfig):
"""
Telemetry manager constructor.
_instance = None
_initialized = False
Args:
config: Telemetry configuration object.
def __new__(cls):
"""
self.config = config
self.run_id = str(uuid.uuid4())
self.enabled = self._check_telemetry_enabled()
Telemetry manager constructor. Creates the singleton instance of this class if
it doesn't already exist.
"""
if cls._instance is None:
cls._instance = super(TelemetryManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Telemetry manager initializer"""
if self._initialized:
return
self.enabled, self.explicit_enable = self._check_telemetry_enabled()
if self.enabled:
# Warn about telemetry collection
if not self.explicit_enable:
LOG.warning(ENABLED_WARNING)
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
self.config = TelemetryConfig()
self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist()
self.system_info = self._get_system_info()
self._init_posthog()
def _check_telemetry_enabled(self) -> bool:
# Register shutdown method to flush posthog telemetry
atexit.register(self.shutdown)
self._initialized = True
@classmethod
def get_instance(cls) -> "TelemetryManager":
if cls._instance is None:
cls._instance = TelemetryManager()
return cls._instance
def _check_telemetry_enabled(self) -> tuple[bool, bool]:
"""
Check if telemetry is enabled based on environment variables.
Check if telemetry is enabled based on environment variables. We also check
whether this is the main process (for the distributed setting and to avoid
sending duplicate PostHog events per GPU).
Note: This is enabled on an opt-in basis. Please consider setting
`AXOLOTL_TELEMETRY=1` to send us valuable data on which models and algos you're
using so we can focus our engineering efforts!
Note: This is enabled by default on an opt-out basis. Set either
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns:
Tuple containing:
- Boolean denoting whether telemetry is enabled or disabled.
- Boolean denoting whether telemetry is explicitly enabled or not.
"""
# Only enable if explicitly opted in
axolotl_telemetry = os.getenv("AXOLOTL_TELEMETRY", "0").lower() in ("1", "true")
# In the distributed setting, check whether we're running on rank 0
if not is_main_process():
return False, False
# Respect DO_NOT_TRACK as an override even if telemetry is enabled
do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true")
# Parse relevant env vars and fill opt-out default values
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
return axolotl_telemetry and not do_not_track
if axolotl_do_not_track is None:
axolotl_do_not_track = "0"
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
# If explicitly enabled, we'll disable the telemetry warning message
explicit_enabled = axolotl_do_not_track in ["0", "false"]
return enabled, explicit_enabled
def _load_whitelist(self) -> dict:
"""Load organization/model whitelist"""
@@ -146,9 +174,7 @@ class TelemetryManager:
for path in Path(error).parents:
sanitized = sanitized.replace(str(path), "")
except (ValueError, RuntimeError) as e:
# ValueError: Invalid path format
# RuntimeError: Other path parsing errors
logger.debug(f"Could not parse path in error message: {e}")
LOG.debug(f"Could not parse path in error message: {e}")
return sanitized
@@ -167,95 +193,38 @@ class TelemetryManager:
return {
"os": platform.system(),
"python_version": platform.python_version(),
"pytorch_version": torch.__version__,
"transformers_version": transformers.__version__,
"axolotl_version": axolotl.__version__,
"cpu_count": psutil.cpu_count(),
"memory_total": psutil.virtual_memory().total,
"gpu_count": len(gpu_info),
"gpu_info": gpu_info,
}
def track_event(self, event_type: str, properties: dict[str, Any]):
def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
"""Track a telemetry event"""
if not self.enabled:
return
if properties is None:
properties = {}
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try:
# Get system info first - most likely source of errors
system_info = self._get_system_info()
LOG.warning(f"*** Sending telemetry for {event_type} ***")
# Send event via PostHog
try:
posthog.capture(
distinct_id=self.config.distinct_id,
event=event_type,
properties={
"run_id": self.run_id,
"system_info": system_info,
**properties,
},
)
except Exception as e:
logger.warning(f"Failed to send telemetry event: {e}")
except (RuntimeError, OSError) as e:
logger.warning(f"Failed to collect system info for telemetry: {e}")
except TypeError as e:
logger.warning(f"Invalid property type in telemetry event: {e}")
except AttributeError as e:
logger.warning(f"Failed to access system attribute for telemetry: {e}")
@contextmanager
def track_training(self, config: dict[str, Any]):
"""Context manager to track training run"""
if not self.enabled:
yield
return
# Track training start
sanitized_config = {
k: v
for k, v in config.items()
if not any(p in k.lower() for p in ["path", "dir", "file"])
}
self.track_event("training_start", {"config": sanitized_config})
try:
yield
# Track successful completion
self.track_event("training_complete", {})
except Exception as e:
# Track error
self.track_event("training_error", {"error": self._sanitize_error(str(e))})
raise
def track_model_load(self, model_config: ModelConfig):
"""Track model loading and configuration"""
if not self.enabled or not self._is_whitelisted(model_config.base_model):
return
self.track_event(
"model_load",
{
"model_config": model_config.to_dict(),
"system_info": self._get_system_info(),
},
)
def track_training_metrics(self, metrics: dict):
"""Track training progress metrics"""
if not self.enabled:
return
self.track_event(
"training_metrics",
{
"duration": metrics.get("duration"),
"peak_memory": metrics.get("peak_memory"),
"steps_completed": metrics.get("steps_completed"),
"current_loss": metrics.get("loss"),
"learning_rate": metrics.get("learning_rate"),
},
)
posthog.capture(
distinct_id=self.run_id,
event=event_type,
properties={
"system_info": self.system_info,
**properties,
},
)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Failed to send telemetry event: {e}")
def shutdown(self):
"""Ensure all queued events are processed before shutdown"""
@@ -263,6 +232,49 @@ class TelemetryManager:
posthog.flush()
def init_telemetry_manager() -> TelemetryManager:
"""Initialize telemetry system"""
return TelemetryManager(TelemetryConfig())
ERROR_HANDLED = False
def track_errors(func: Callable) -> Callable:
"""Decorator to track errors in a function"""
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True
# Get function module path
module = getmodule(func)
module_path = (
f"{module.__name__}.{func.__name__}" if module else func.__name__
)
# Get stack trace
stack_trace = "".join(
traceback.format_exception(
type(exception), exception, exception.__traceback__
)
)
# Send error telemetry
telemetry_manager.track_event(
event_type=f"{module_path}-error",
properties={
"exception": str(exception),
"stack_trace": stack_trace,
},
)
raise
return wrapper

View File

@@ -3,6 +3,7 @@ organizations:
- "huggingface"
- "nvidia"
- "facebook"
- "mistralai"
- "briaai"
- "unsloth"
- "NousResearch"

View File

@@ -10,6 +10,7 @@ from contextlib import ExitStack
from pathlib import Path
from typing import Any, Dict
from axolotl.telemetry.manager import track_errors
import torch
import transformers.modelcard
from accelerate.utils import save_fsdp_model
@@ -47,6 +48,7 @@ except ImportError:
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def setup_model_and_tokenizer(
cfg: DictDefault,
@@ -64,7 +66,10 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_cocnfig or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
@@ -83,6 +88,14 @@ def setup_model_and_tokenizer(
if model.generation_config is not None:
model.generation_config.do_sample = True
TELEMETRY_MANAGER.track_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.track_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -527,6 +540,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
)
@track_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
@@ -565,10 +579,12 @@ def train(
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
setup_signal_handler(cfg, model, safe_serialization)
setup_model_card(cfg)
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
TELEMETRY_MANAGER.track_event(event_type="train-start")
execute_training(cfg, trainer, resume_from_checkpoint)
TELEMETRY_MANAGER.track_event(event_type="train-end")
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)

View File

@@ -1259,7 +1259,7 @@ class AxolotlInputConfig(
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
"""Wrapper to validate GPU capabilities with the config options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities

View File

@@ -1,71 +1,65 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name
import os
from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
from axolotl.telemetry import TelemetryManager
@pytest.fixture
def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing"""
whitelist_content = {
"organizations": ["meta", "mistral"],
"models": ["llama", "mistral-7b"],
"organizations": ["meta-llama", "mistralai"],
}
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w") as f:
with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f)
return str(whitelist_file)
@pytest.fixture
def config(mock_whitelist):
"""Create a TelemetryConfig with test settings"""
return TelemetryConfig(
host="https://test.posthog.com", whitelist_path=mock_whitelist
)
@pytest.fixture
def manager(config):
def manager():
"""Create a TelemetryManager instance with mocked PostHog"""
with patch("posthog.capture"):
return TelemetryManager(config)
return TelemetryManager()
def test_telemetry_disabled_by_default():
"""Test that telemetry is disabled by default"""
manager = TelemetryManager(TelemetryConfig())
manager = TelemetryManager()
assert not manager.enabled
def test_telemetry_opt_in():
"""Test that telemetry can be enabled via environment variable"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
manager = TelemetryManager(TelemetryConfig())
manager = TelemetryManager()
assert manager.enabled
def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
manager = TelemetryManager(TelemetryConfig())
manager = TelemetryManager()
assert not manager.enabled
# pylint: disable=protected-access
def test_whitelist_checking(manager):
"""Test model whitelist functionality"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
# Should match organization
assert manager._is_whitelisted("meta/llama-7b")
# Should match model name
assert manager._is_whitelisted("mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("meta/Llama-7b")
# Should match organization
assert manager._is_whitelisted("meta-llama/llama-7b")
# Should match model name
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("meta/Llama-7b")
def test_event_tracking(manager):
@@ -81,33 +75,6 @@ def test_event_tracking(manager):
assert "system_info" in mock_capture.call_args[1]["properties"]
def test_model_tracking(manager):
"""Test model load tracking"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
model_config = ModelConfig(
base_model="meta/llama-7b",
model_type="decoder",
hidden_size=4096,
num_layers=32,
num_attention_heads=32,
tokenizer_config={},
flash_attention=True,
quantization_config=None,
training_approach="lora",
)
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_model_load(model_config)
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "model_load"
assert (
mock_capture.call_args[1]["properties"]["model_config"]
== model_config.to_dict()
)
def test_training_context(manager):
"""Test training context manager"""
config = {"model": "llama", "batch_size": 8}
@@ -141,6 +108,7 @@ def test_training_error(manager):
assert "training_error" in events
# pylint: disable=protected-access
def test_path_sanitization(manager):
"""Test path sanitization"""
path = "/home/user/sensitive/data.txt"
@@ -149,6 +117,7 @@ def test_path_sanitization(manager):
assert "/home/user" not in sanitized
# pylint: disable=protected-access
def test_error_sanitization(manager):
"""Test error message sanitization"""
error = "Failed to load /home/user/sensitive/data.txt: File not found"