progress on telemetry: config load, process, model load, train start / end, error tracking
This commit is contained in:
6
docs/telemetry.qmd
Normal file
6
docs/telemetry.qmd
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
title: Telemetry
|
||||||
|
description: A description of the opt-out telemetry implementation in Axolotl.
|
||||||
|
---
|
||||||
|
|
||||||
|
TODO.
|
||||||
@@ -14,6 +14,8 @@ import yaml
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.telemetry import TelemetryManager
|
||||||
|
from axolotl.telemetry.manager import track_errors
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
@@ -28,6 +30,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
"""
|
"""
|
||||||
@@ -159,6 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
|
|||||||
plugin_manager.cfg = cfg
|
plugin_manager.cfg = cfg
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def load_cfg(
|
def load_cfg(
|
||||||
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||||
) -> DictDefault:
|
) -> DictDefault:
|
||||||
@@ -192,6 +197,8 @@ def load_cfg(
|
|||||||
temp_file.close()
|
temp_file.close()
|
||||||
cfg.axolotl_config_path = temp_file.name
|
cfg.axolotl_config_path = temp_file.name
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
|
||||||
|
|
||||||
# If there are any options passed in the cli, if it is something that seems valid
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
# from the yaml, then overwrite the value
|
# from the yaml, then overwrite the value
|
||||||
cfg_keys = cfg.keys()
|
cfg_keys = cfg.keys()
|
||||||
@@ -233,4 +240,6 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ def load_lora(
|
|||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def load_adapter(
|
def load_adapter(
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ class ModelLoader:
|
|||||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||||
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||||
"""Load and prepare the model with all configurations and patches.
|
"""Load and prepare the model with all configurations and patches.
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||||
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
|
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
|
||||||
|
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ def modify_tokenizer_files(
|
|||||||
return tokenizer_dir
|
return tokenizer_dir
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
"""Load and configure the tokenizer based on the provided config."""
|
"""Load and configure the tokenizer based on the provided config."""
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
"""Init for axolotl.telemetry module."""
|
"""Init for axolotl.telemetry module."""
|
||||||
|
|
||||||
from .manager import (
|
from .manager import TelemetryConfig, TelemetryManager
|
||||||
ModelConfig,
|
|
||||||
TelemetryConfig,
|
|
||||||
TelemetryManager,
|
|
||||||
init_telemetry_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TelemetryConfig",
|
"TelemetryConfig",
|
||||||
"TelemetryManager",
|
"TelemetryManager",
|
||||||
"ModelConfig",
|
|
||||||
"init_telemetry_manager",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,70 +1,46 @@
|
|||||||
"""Telemetry manager and associated utilities."""
|
"""Telemetry manager and associated utilities."""
|
||||||
|
|
||||||
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import wraps
|
||||||
|
from inspect import getmodule
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
import posthog
|
import posthog
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import axolotl
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
|
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
|
||||||
|
ENABLED_WARNING_SLEEP_SECONDS = 10
|
||||||
|
ENABLED_WARNING = (
|
||||||
@dataclass
|
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
||||||
class ModelConfig:
|
"- Which models and configurations are most commonly used\n"
|
||||||
"""Tracked model configuration details"""
|
"- What hardware setups need to be supported\n"
|
||||||
|
"- Where users encounter errors\n\n"
|
||||||
base_model: str
|
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
|
||||||
model_type: str
|
"To disable telemetry, set either:\n"
|
||||||
hidden_size: int
|
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
|
||||||
num_layers: int
|
"- DO_NOT_TRACK=1 (Global standard)\n\n"
|
||||||
num_attention_heads: int
|
"To remove this warning and continue with telemetry enabled,"
|
||||||
tokenizer_config: dict
|
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
|
||||||
flash_attention: bool
|
"No personally identifiable information is collected."
|
||||||
quantization_config: dict | None
|
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
|
||||||
training_approach: str # 'lora', 'qlora', 'full_finetune'
|
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
|
||||||
|
)
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: dict) -> "ModelConfig":
|
|
||||||
"""Create from Axolotl config dict"""
|
|
||||||
return cls(
|
|
||||||
base_model=config.get("base_model", ""),
|
|
||||||
model_type=config.get("model_type", ""),
|
|
||||||
hidden_size=config.get("hidden_size", 0),
|
|
||||||
num_layers=config.get("num_layers", 0),
|
|
||||||
num_attention_heads=config.get("num_attention_heads", 0),
|
|
||||||
tokenizer_config=config.get("tokenizer", {}),
|
|
||||||
flash_attention=config.get("flash_attention", False),
|
|
||||||
quantization_config=config.get("quantization", None),
|
|
||||||
training_approach=config.get("training_approach", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""Convert to PostHog-compatible dict"""
|
|
||||||
return {
|
|
||||||
"base_model": self.base_model,
|
|
||||||
"model_type": self.model_type,
|
|
||||||
"architecture": {
|
|
||||||
"hidden_size": self.hidden_size,
|
|
||||||
"num_layers": self.num_layers,
|
|
||||||
"num_attention_heads": self.num_attention_heads,
|
|
||||||
},
|
|
||||||
"optimizations": {
|
|
||||||
"flash_attention": self.flash_attention,
|
|
||||||
"quantization": self.quantization_config is not None,
|
|
||||||
"quantization_config": self.quantization_config,
|
|
||||||
},
|
|
||||||
"training_approach": self.training_approach,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -76,43 +52,95 @@ class TelemetryConfig:
|
|||||||
batch_size: int = 10
|
batch_size: int = 10
|
||||||
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
|
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
|
||||||
retention_days: int = 365
|
retention_days: int = 365
|
||||||
distinct_id: str = str(uuid.uuid4())
|
|
||||||
schema_version: str = "0.1.0"
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
class TelemetryManager:
|
||||||
"""Manages telemetry collection and transmission"""
|
"""Manages telemetry collection and transmission"""
|
||||||
|
|
||||||
def __init__(self, config: TelemetryConfig):
|
_instance = None
|
||||||
"""
|
_initialized = False
|
||||||
Telemetry manager constructor.
|
|
||||||
|
|
||||||
Args:
|
def __new__(cls):
|
||||||
config: Telemetry configuration object.
|
|
||||||
"""
|
"""
|
||||||
self.config = config
|
Telemetry manager constructor. Creates the singleton instance of this class if
|
||||||
self.run_id = str(uuid.uuid4())
|
it doesn't already exist.
|
||||||
self.enabled = self._check_telemetry_enabled()
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(TelemetryManager, cls).__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Telemetry manager initializer"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.enabled, self.explicit_enable = self._check_telemetry_enabled()
|
||||||
|
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
|
# Warn about telemetry collection
|
||||||
|
if not self.explicit_enable:
|
||||||
|
LOG.warning(ENABLED_WARNING)
|
||||||
|
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
|
||||||
|
|
||||||
|
self.config = TelemetryConfig()
|
||||||
|
self.run_id = str(uuid.uuid4())
|
||||||
self.whitelist = self._load_whitelist()
|
self.whitelist = self._load_whitelist()
|
||||||
|
self.system_info = self._get_system_info()
|
||||||
self._init_posthog()
|
self._init_posthog()
|
||||||
|
|
||||||
def _check_telemetry_enabled(self) -> bool:
|
# Register shutdown method to flush posthog telemetry
|
||||||
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "TelemetryManager":
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = TelemetryManager()
|
||||||
|
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _check_telemetry_enabled(self) -> tuple[bool, bool]:
|
||||||
"""
|
"""
|
||||||
Check if telemetry is enabled based on environment variables.
|
Check if telemetry is enabled based on environment variables. We also check
|
||||||
|
whether this is the main process (for the distributed setting and to avoid
|
||||||
|
sending duplicate PostHog events per GPU).
|
||||||
|
|
||||||
Note: This is enabled on an opt-in basis. Please consider setting
|
Note: This is enabled by default on an opt-out basis. Set either
|
||||||
`AXOLOTL_TELEMETRY=1` to send us valuable data on which models and algos you're
|
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
|
||||||
using so we can focus our engineering efforts!
|
details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- Boolean denoting whether telemetry is enabled or disabled.
|
||||||
|
- Boolean denoting whether telemetry is explicitly enabled or not.
|
||||||
"""
|
"""
|
||||||
# Only enable if explicitly opted in
|
# In the distributed setting, check whether we're running on rank 0
|
||||||
axolotl_telemetry = os.getenv("AXOLOTL_TELEMETRY", "0").lower() in ("1", "true")
|
if not is_main_process():
|
||||||
|
return False, False
|
||||||
|
|
||||||
# Respect DO_NOT_TRACK as an override even if telemetry is enabled
|
# Parse relevant env vars and fill opt-out default values
|
||||||
do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true")
|
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
||||||
|
do_not_track = os.getenv("DO_NOT_TRACK")
|
||||||
|
|
||||||
return axolotl_telemetry and not do_not_track
|
if axolotl_do_not_track is None:
|
||||||
|
axolotl_do_not_track = "0"
|
||||||
|
|
||||||
|
if do_not_track is None:
|
||||||
|
do_not_track = "0"
|
||||||
|
|
||||||
|
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
|
||||||
|
enabled = axolotl_do_not_track.lower() not in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
) and do_not_track.lower() not in ("1", "true")
|
||||||
|
|
||||||
|
# If explicitly enabled, we'll disable the telemetry warning message
|
||||||
|
explicit_enabled = axolotl_do_not_track in ["0", "false"]
|
||||||
|
|
||||||
|
return enabled, explicit_enabled
|
||||||
|
|
||||||
def _load_whitelist(self) -> dict:
|
def _load_whitelist(self) -> dict:
|
||||||
"""Load organization/model whitelist"""
|
"""Load organization/model whitelist"""
|
||||||
@@ -146,9 +174,7 @@ class TelemetryManager:
|
|||||||
for path in Path(error).parents:
|
for path in Path(error).parents:
|
||||||
sanitized = sanitized.replace(str(path), "")
|
sanitized = sanitized.replace(str(path), "")
|
||||||
except (ValueError, RuntimeError) as e:
|
except (ValueError, RuntimeError) as e:
|
||||||
# ValueError: Invalid path format
|
LOG.debug(f"Could not parse path in error message: {e}")
|
||||||
# RuntimeError: Other path parsing errors
|
|
||||||
logger.debug(f"Could not parse path in error message: {e}")
|
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@@ -167,95 +193,38 @@ class TelemetryManager:
|
|||||||
return {
|
return {
|
||||||
"os": platform.system(),
|
"os": platform.system(),
|
||||||
"python_version": platform.python_version(),
|
"python_version": platform.python_version(),
|
||||||
|
"pytorch_version": torch.__version__,
|
||||||
|
"transformers_version": transformers.__version__,
|
||||||
|
"axolotl_version": axolotl.__version__,
|
||||||
"cpu_count": psutil.cpu_count(),
|
"cpu_count": psutil.cpu_count(),
|
||||||
"memory_total": psutil.virtual_memory().total,
|
"memory_total": psutil.virtual_memory().total,
|
||||||
"gpu_count": len(gpu_info),
|
"gpu_count": len(gpu_info),
|
||||||
"gpu_info": gpu_info,
|
"gpu_info": gpu_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
def track_event(self, event_type: str, properties: dict[str, Any]):
|
def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
||||||
"""Track a telemetry event"""
|
"""Track a telemetry event"""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if properties is None:
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
||||||
try:
|
try:
|
||||||
# Get system info first - most likely source of errors
|
LOG.warning(f"*** Sending telemetry for {event_type} ***")
|
||||||
system_info = self._get_system_info()
|
|
||||||
|
|
||||||
# Send event via PostHog
|
# Send event via PostHog
|
||||||
try:
|
posthog.capture(
|
||||||
posthog.capture(
|
distinct_id=self.run_id,
|
||||||
distinct_id=self.config.distinct_id,
|
event=event_type,
|
||||||
event=event_type,
|
properties={
|
||||||
properties={
|
"system_info": self.system_info,
|
||||||
"run_id": self.run_id,
|
**properties,
|
||||||
"system_info": system_info,
|
},
|
||||||
**properties,
|
)
|
||||||
},
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
)
|
LOG.warning(f"Failed to send telemetry event: {e}")
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to send telemetry event: {e}")
|
|
||||||
except (RuntimeError, OSError) as e:
|
|
||||||
logger.warning(f"Failed to collect system info for telemetry: {e}")
|
|
||||||
except TypeError as e:
|
|
||||||
logger.warning(f"Invalid property type in telemetry event: {e}")
|
|
||||||
except AttributeError as e:
|
|
||||||
logger.warning(f"Failed to access system attribute for telemetry: {e}")
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def track_training(self, config: dict[str, Any]):
|
|
||||||
"""Context manager to track training run"""
|
|
||||||
if not self.enabled:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
# Track training start
|
|
||||||
sanitized_config = {
|
|
||||||
k: v
|
|
||||||
for k, v in config.items()
|
|
||||||
if not any(p in k.lower() for p in ["path", "dir", "file"])
|
|
||||||
}
|
|
||||||
|
|
||||||
self.track_event("training_start", {"config": sanitized_config})
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
# Track successful completion
|
|
||||||
self.track_event("training_complete", {})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Track error
|
|
||||||
self.track_event("training_error", {"error": self._sanitize_error(str(e))})
|
|
||||||
raise
|
|
||||||
|
|
||||||
def track_model_load(self, model_config: ModelConfig):
|
|
||||||
"""Track model loading and configuration"""
|
|
||||||
if not self.enabled or not self._is_whitelisted(model_config.base_model):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.track_event(
|
|
||||||
"model_load",
|
|
||||||
{
|
|
||||||
"model_config": model_config.to_dict(),
|
|
||||||
"system_info": self._get_system_info(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def track_training_metrics(self, metrics: dict):
|
|
||||||
"""Track training progress metrics"""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.track_event(
|
|
||||||
"training_metrics",
|
|
||||||
{
|
|
||||||
"duration": metrics.get("duration"),
|
|
||||||
"peak_memory": metrics.get("peak_memory"),
|
|
||||||
"steps_completed": metrics.get("steps_completed"),
|
|
||||||
"current_loss": metrics.get("loss"),
|
|
||||||
"learning_rate": metrics.get("learning_rate"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Ensure all queued events are processed before shutdown"""
|
"""Ensure all queued events are processed before shutdown"""
|
||||||
@@ -263,6 +232,49 @@ class TelemetryManager:
|
|||||||
posthog.flush()
|
posthog.flush()
|
||||||
|
|
||||||
|
|
||||||
def init_telemetry_manager() -> TelemetryManager:
|
ERROR_HANDLED = False
|
||||||
"""Initialize telemetry system"""
|
|
||||||
return TelemetryManager(TelemetryConfig())
|
|
||||||
|
def track_errors(func: Callable) -> Callable:
|
||||||
|
"""Decorator to track errors in a function"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
telemetry_manager = TelemetryManager.get_instance()
|
||||||
|
if not telemetry_manager.enabled:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as exception:
|
||||||
|
# Only track if we're not already handling an error. This prevents us from
|
||||||
|
# capturing an error more than once in nested decorated function calls.
|
||||||
|
global ERROR_HANDLED # pylint: disable=global-statement
|
||||||
|
if not ERROR_HANDLED:
|
||||||
|
ERROR_HANDLED = True
|
||||||
|
|
||||||
|
# Get function module path
|
||||||
|
module = getmodule(func)
|
||||||
|
module_path = (
|
||||||
|
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get stack trace
|
||||||
|
stack_trace = "".join(
|
||||||
|
traceback.format_exception(
|
||||||
|
type(exception), exception, exception.__traceback__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send error telemetry
|
||||||
|
telemetry_manager.track_event(
|
||||||
|
event_type=f"{module_path}-error",
|
||||||
|
properties={
|
||||||
|
"exception": str(exception),
|
||||||
|
"stack_trace": stack_trace,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ organizations:
|
|||||||
- "huggingface"
|
- "huggingface"
|
||||||
- "nvidia"
|
- "nvidia"
|
||||||
- "facebook"
|
- "facebook"
|
||||||
|
- "mistralai"
|
||||||
- "briaai"
|
- "briaai"
|
||||||
- "unsloth"
|
- "unsloth"
|
||||||
- "NousResearch"
|
- "NousResearch"
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from contextlib import ExitStack
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from axolotl.telemetry.manager import track_errors
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
@@ -47,6 +48,7 @@ except ImportError:
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
|
|
||||||
def setup_model_and_tokenizer(
|
def setup_model_and_tokenizer(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -64,7 +66,10 @@ def setup_model_and_tokenizer(
|
|||||||
`None`), and processor (if multimodal, else `None`).
|
`None`), and processor (if multimodal, else `None`).
|
||||||
"""
|
"""
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.debug(
|
||||||
|
f"loading tokenizer... {cfg.tokenizer_cocnfig or cfg.base_model_config}",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
# Load processor for multimodal models if needed
|
# Load processor for multimodal models if needed
|
||||||
@@ -83,6 +88,14 @@ def setup_model_and_tokenizer(
|
|||||||
if model.generation_config is not None:
|
if model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(
|
||||||
|
event_type="model-load", properties=model.config.to_dict()
|
||||||
|
)
|
||||||
|
if peft_config:
|
||||||
|
TELEMETRY_MANAGER.track_event(
|
||||||
|
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||||
|
)
|
||||||
|
|
||||||
# Apply freezing if specified
|
# Apply freezing if specified
|
||||||
if cfg.unfrozen_parameters:
|
if cfg.unfrozen_parameters:
|
||||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||||
@@ -527,6 +540,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def train(
|
def train(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||||
@@ -565,10 +579,12 @@ def train(
|
|||||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||||
setup_signal_handler(cfg, model, safe_serialization)
|
setup_signal_handler(cfg, model, safe_serialization)
|
||||||
setup_model_card(cfg)
|
setup_model_card(cfg)
|
||||||
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||||
|
|
||||||
# Execute the training
|
# Execute the training
|
||||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
TELEMETRY_MANAGER.track_event(event_type="train-start")
|
||||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="train-end")
|
||||||
|
|
||||||
# Save the trained model and cleanup
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
|
|||||||
@@ -1259,7 +1259,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""Wrapper to validate GPU capabilities with the config options"""
|
||||||
|
|
||||||
capabilities: GPUCapabilities
|
capabilities: GPUCapabilities
|
||||||
env_capabilities: EnvCapabilities
|
env_capabilities: EnvCapabilities
|
||||||
|
|||||||
@@ -1,71 +1,65 @@
|
|||||||
|
"""Tests for TelemetryManager class and utilities"""
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
|
from axolotl.telemetry import TelemetryManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_whitelist(tmp_path):
|
def mock_whitelist(tmp_path):
|
||||||
"""Create a temporary whitelist file for testing"""
|
"""Create a temporary whitelist file for testing"""
|
||||||
whitelist_content = {
|
whitelist_content = {
|
||||||
"organizations": ["meta", "mistral"],
|
"organizations": ["meta-llama", "mistralai"],
|
||||||
"models": ["llama", "mistral-7b"],
|
|
||||||
}
|
}
|
||||||
whitelist_file = tmp_path / "whitelist.yaml"
|
whitelist_file = tmp_path / "whitelist.yaml"
|
||||||
with open(whitelist_file, "w") as f:
|
with open(whitelist_file, "w", encoding="utf-8") as f:
|
||||||
yaml.dump(whitelist_content, f)
|
yaml.dump(whitelist_content, f)
|
||||||
return str(whitelist_file)
|
return str(whitelist_file)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def config(mock_whitelist):
|
def manager():
|
||||||
"""Create a TelemetryConfig with test settings"""
|
|
||||||
return TelemetryConfig(
|
|
||||||
host="https://test.posthog.com", whitelist_path=mock_whitelist
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def manager(config):
|
|
||||||
"""Create a TelemetryManager instance with mocked PostHog"""
|
"""Create a TelemetryManager instance with mocked PostHog"""
|
||||||
with patch("posthog.capture"):
|
with patch("posthog.capture"):
|
||||||
return TelemetryManager(config)
|
return TelemetryManager()
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_by_default():
|
def test_telemetry_disabled_by_default():
|
||||||
"""Test that telemetry is disabled by default"""
|
"""Test that telemetry is disabled by default"""
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_opt_in():
|
def test_telemetry_opt_in():
|
||||||
"""Test that telemetry can be enabled via environment variable"""
|
"""Test that telemetry can be enabled via environment variable"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager()
|
||||||
assert manager.enabled
|
assert manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_do_not_track_override():
|
def test_do_not_track_override():
|
||||||
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
|
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
def test_whitelist_checking(manager):
|
def test_whitelist_checking(manager):
|
||||||
"""Test model whitelist functionality"""
|
"""Test model whitelist functionality"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
# Should match organization
|
||||||
# Should match organization
|
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||||
assert manager._is_whitelisted("meta/llama-7b")
|
# Should match model name
|
||||||
# Should match model name
|
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||||
assert manager._is_whitelisted("mistral-7b-instruct")
|
# Should not match
|
||||||
# Should not match
|
assert not manager._is_whitelisted("unknown/model")
|
||||||
assert not manager._is_whitelisted("unknown/model")
|
# Should handle case insensitively
|
||||||
# Should handle case insensitively
|
assert manager._is_whitelisted("meta/Llama-7b")
|
||||||
assert manager._is_whitelisted("meta/Llama-7b")
|
|
||||||
|
|
||||||
|
|
||||||
def test_event_tracking(manager):
|
def test_event_tracking(manager):
|
||||||
@@ -81,33 +75,6 @@ def test_event_tracking(manager):
|
|||||||
assert "system_info" in mock_capture.call_args[1]["properties"]
|
assert "system_info" in mock_capture.call_args[1]["properties"]
|
||||||
|
|
||||||
|
|
||||||
def test_model_tracking(manager):
|
|
||||||
"""Test model load tracking"""
|
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
|
||||||
model_config = ModelConfig(
|
|
||||||
base_model="meta/llama-7b",
|
|
||||||
model_type="decoder",
|
|
||||||
hidden_size=4096,
|
|
||||||
num_layers=32,
|
|
||||||
num_attention_heads=32,
|
|
||||||
tokenizer_config={},
|
|
||||||
flash_attention=True,
|
|
||||||
quantization_config=None,
|
|
||||||
training_approach="lora",
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
|
||||||
manager.enabled = True
|
|
||||||
manager.track_model_load(model_config)
|
|
||||||
|
|
||||||
assert mock_capture.called
|
|
||||||
assert mock_capture.call_args[1]["event"] == "model_load"
|
|
||||||
assert (
|
|
||||||
mock_capture.call_args[1]["properties"]["model_config"]
|
|
||||||
== model_config.to_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_training_context(manager):
|
def test_training_context(manager):
|
||||||
"""Test training context manager"""
|
"""Test training context manager"""
|
||||||
config = {"model": "llama", "batch_size": 8}
|
config = {"model": "llama", "batch_size": 8}
|
||||||
@@ -141,6 +108,7 @@ def test_training_error(manager):
|
|||||||
assert "training_error" in events
|
assert "training_error" in events
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
def test_path_sanitization(manager):
|
def test_path_sanitization(manager):
|
||||||
"""Test path sanitization"""
|
"""Test path sanitization"""
|
||||||
path = "/home/user/sensitive/data.txt"
|
path = "/home/user/sensitive/data.txt"
|
||||||
@@ -149,6 +117,7 @@ def test_path_sanitization(manager):
|
|||||||
assert "/home/user" not in sanitized
|
assert "/home/user" not in sanitized
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
def test_error_sanitization(manager):
|
def test_error_sanitization(manager):
|
||||||
"""Test error message sanitization"""
|
"""Test error message sanitization"""
|
||||||
error = "Failed to load /home/user/sensitive/data.txt: File not found"
|
error = "Failed to load /home/user/sensitive/data.txt: File not found"
|
||||||
|
|||||||
Reference in New Issue
Block a user