progress on telemetry: config load, process, model load, train start / end, error tracking
This commit is contained in:
6
docs/telemetry.qmd
Normal file
6
docs/telemetry.qmd
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
title: Telemetry
|
||||||
|
description: A description of the opt-out telemetry implementation in Axolotl.
|
||||||
|
---
|
||||||
|
|
||||||
|
TODO.
|
||||||
@@ -14,6 +14,8 @@ import yaml
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.telemetry import TelemetryManager
|
||||||
|
from axolotl.telemetry.manager import track_errors
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
@@ -27,6 +29,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
"""
|
"""
|
||||||
@@ -152,6 +156,7 @@ def prepare_plugins(cfg: DictDefault):
|
|||||||
plugin_manager.register(plugin_name)
|
plugin_manager.register(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
||||||
"""
|
"""
|
||||||
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||||
@@ -172,6 +177,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
|
||||||
|
|
||||||
# If there are any options passed in the cli, if it is something that seems valid
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
# from the yaml, then overwrite the value
|
# from the yaml, then overwrite the value
|
||||||
cfg_keys = cfg.keys()
|
cfg_keys = cfg.keys()
|
||||||
@@ -214,4 +221,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
"""Init for axolotl.telemetry module."""
|
"""Init for axolotl.telemetry module."""
|
||||||
|
|
||||||
from .manager import (
|
from .manager import TelemetryConfig, TelemetryManager
|
||||||
ModelConfig,
|
|
||||||
TelemetryConfig,
|
|
||||||
TelemetryManager,
|
|
||||||
init_telemetry_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TelemetryConfig",
|
"TelemetryConfig",
|
||||||
"TelemetryManager",
|
"TelemetryManager",
|
||||||
"ModelConfig",
|
|
||||||
"init_telemetry_manager",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,70 +1,46 @@
|
|||||||
"""Telemetry manager and associated utilities."""
|
"""Telemetry manager and associated utilities."""
|
||||||
|
|
||||||
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import wraps
|
||||||
|
from inspect import getmodule
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
import posthog
|
import posthog
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import axolotl
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
|
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
|
||||||
|
ENABLED_WARNING_SLEEP_SECONDS = 10
|
||||||
|
ENABLED_WARNING = (
|
||||||
@dataclass
|
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
||||||
class ModelConfig:
|
"- Which models and configurations are most commonly used\n"
|
||||||
"""Tracked model configuration details"""
|
"- What hardware setups need to be supported\n"
|
||||||
|
"- Where users encounter errors\n\n"
|
||||||
base_model: str
|
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
|
||||||
model_type: str
|
"To disable telemetry, set either:\n"
|
||||||
hidden_size: int
|
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
|
||||||
num_layers: int
|
"- DO_NOT_TRACK=1 (Global standard)\n\n"
|
||||||
num_attention_heads: int
|
"To remove this warning and continue with telemetry enabled, "
|
||||||
tokenizer_config: dict
|
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
|
||||||
flash_attention: bool
|
"No personally identifiable information is collected. "
|
||||||
quantization_config: dict | None
|
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
|
||||||
training_approach: str # 'lora', 'qlora', 'full_finetune'
|
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
|
||||||
|
)
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: dict) -> "ModelConfig":
|
|
||||||
"""Create from Axolotl config dict"""
|
|
||||||
return cls(
|
|
||||||
base_model=config.get("base_model", ""),
|
|
||||||
model_type=config.get("model_type", ""),
|
|
||||||
hidden_size=config.get("hidden_size", 0),
|
|
||||||
num_layers=config.get("num_layers", 0),
|
|
||||||
num_attention_heads=config.get("num_attention_heads", 0),
|
|
||||||
tokenizer_config=config.get("tokenizer", {}),
|
|
||||||
flash_attention=config.get("flash_attention", False),
|
|
||||||
quantization_config=config.get("quantization", None),
|
|
||||||
training_approach=config.get("training_approach", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""Convert to PostHog-compatible dict"""
|
|
||||||
return {
|
|
||||||
"base_model": self.base_model,
|
|
||||||
"model_type": self.model_type,
|
|
||||||
"architecture": {
|
|
||||||
"hidden_size": self.hidden_size,
|
|
||||||
"num_layers": self.num_layers,
|
|
||||||
"num_attention_heads": self.num_attention_heads,
|
|
||||||
},
|
|
||||||
"optimizations": {
|
|
||||||
"flash_attention": self.flash_attention,
|
|
||||||
"quantization": self.quantization_config is not None,
|
|
||||||
"quantization_config": self.quantization_config,
|
|
||||||
},
|
|
||||||
"training_approach": self.training_approach,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -76,43 +52,95 @@ class TelemetryConfig:
|
|||||||
batch_size: int = 10
|
batch_size: int = 10
|
||||||
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
|
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
|
||||||
retention_days: int = 365
|
retention_days: int = 365
|
||||||
distinct_id: str = str(uuid.uuid4())
|
|
||||||
schema_version: str = "0.1.0"
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
class TelemetryManager:
|
||||||
"""Manages telemetry collection and transmission"""
|
"""Manages telemetry collection and transmission"""
|
||||||
|
|
||||||
def __init__(self, config: TelemetryConfig):
|
_instance = None
|
||||||
"""
|
_initialized = False
|
||||||
Telemetry manager constructor.
|
|
||||||
|
|
||||||
Args:
|
def __new__(cls):
|
||||||
config: Telemetry configuration object.
|
|
||||||
"""
|
"""
|
||||||
self.config = config
|
Telemetry manager constructor. Creates the singleton instance of this class if
|
||||||
self.run_id = str(uuid.uuid4())
|
it doesn't already exist.
|
||||||
self.enabled = self._check_telemetry_enabled()
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(TelemetryManager, cls).__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Telemetry manager initializer"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.enabled, self.explicit_enable = self._check_telemetry_enabled()
|
||||||
|
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
|
# Warn about telemetry collection
|
||||||
|
if not self.explicit_enable:
|
||||||
|
LOG.warning(ENABLED_WARNING)
|
||||||
|
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
|
||||||
|
|
||||||
|
self.config = TelemetryConfig()
|
||||||
|
self.run_id = str(uuid.uuid4())
|
||||||
self.whitelist = self._load_whitelist()
|
self.whitelist = self._load_whitelist()
|
||||||
|
self.system_info = self._get_system_info()
|
||||||
self._init_posthog()
|
self._init_posthog()
|
||||||
|
|
||||||
def _check_telemetry_enabled(self) -> bool:
|
# Register shutdown method to flush posthog telemetry
|
||||||
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "TelemetryManager":
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = TelemetryManager()
|
||||||
|
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _check_telemetry_enabled(self) -> tuple[bool, bool]:
|
||||||
"""
|
"""
|
||||||
Check if telemetry is enabled based on environment variables.
|
Check if telemetry is enabled based on environment variables. We also check
|
||||||
|
whether this is the main process (for the distributed setting and to avoid
|
||||||
|
sending duplicate PostHog events per GPU).
|
||||||
|
|
||||||
Note: This is enabled on an opt-in basis. Please consider setting
|
Note: This is enabled by default on an opt-out basis. Set either
|
||||||
`AXOLOTL_TELEMETRY=1` to send us valuable data on which models and algos you're
|
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
|
||||||
using so we can focus our engineering efforts!
|
details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- Boolean denoting whether telemetry is enabled or disabled.
|
||||||
|
- Boolean denoting whether telemetry is explicitly enabled or not.
|
||||||
"""
|
"""
|
||||||
# Only enable if explicitly opted in
|
# In the distributed setting, check whether we're running on rank 0
|
||||||
axolotl_telemetry = os.getenv("AXOLOTL_TELEMETRY", "0").lower() in ("1", "true")
|
if not is_main_process():
|
||||||
|
return False, False
|
||||||
|
|
||||||
# Respect DO_NOT_TRACK as an override even if telemetry is enabled
|
# Parse relevant env vars and fill opt-out default values
|
||||||
do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true")
|
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
||||||
|
do_not_track = os.getenv("DO_NOT_TRACK")
|
||||||
|
|
||||||
return axolotl_telemetry and not do_not_track
|
if axolotl_do_not_track is None:
|
||||||
|
axolotl_do_not_track = "0"
|
||||||
|
|
||||||
|
if do_not_track is None:
|
||||||
|
do_not_track = "0"
|
||||||
|
|
||||||
|
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
|
||||||
|
enabled = axolotl_do_not_track.lower() not in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
) and do_not_track.lower() not in ("1", "true")
|
||||||
|
|
||||||
|
# If explicitly enabled, we'll disable the telemetry warning message
|
||||||
|
explicit_enabled = os.getenv("AXOLOTL_DO_NOT_TRACK", "").lower() in ("0", "false")
|
||||||
|
|
||||||
|
return enabled, explicit_enabled
|
||||||
|
|
||||||
def _load_whitelist(self) -> dict:
|
def _load_whitelist(self) -> dict:
|
||||||
"""Load organization/model whitelist"""
|
"""Load organization/model whitelist"""
|
||||||
@@ -146,9 +174,7 @@ class TelemetryManager:
|
|||||||
for path in Path(error).parents:
|
for path in Path(error).parents:
|
||||||
sanitized = sanitized.replace(str(path), "")
|
sanitized = sanitized.replace(str(path), "")
|
||||||
except (ValueError, RuntimeError) as e:
|
except (ValueError, RuntimeError) as e:
|
||||||
# ValueError: Invalid path format
|
LOG.debug(f"Could not parse path in error message: {e}")
|
||||||
# RuntimeError: Other path parsing errors
|
|
||||||
logger.debug(f"Could not parse path in error message: {e}")
|
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@@ -167,95 +193,38 @@ class TelemetryManager:
|
|||||||
return {
|
return {
|
||||||
"os": platform.system(),
|
"os": platform.system(),
|
||||||
"python_version": platform.python_version(),
|
"python_version": platform.python_version(),
|
||||||
|
"pytorch_version": torch.__version__,
|
||||||
|
"transformers_version": transformers.__version__,
|
||||||
|
"axolotl_version": axolotl.__version__,
|
||||||
"cpu_count": psutil.cpu_count(),
|
"cpu_count": psutil.cpu_count(),
|
||||||
"memory_total": psutil.virtual_memory().total,
|
"memory_total": psutil.virtual_memory().total,
|
||||||
"gpu_count": len(gpu_info),
|
"gpu_count": len(gpu_info),
|
||||||
"gpu_info": gpu_info,
|
"gpu_info": gpu_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
def track_event(self, event_type: str, properties: dict[str, Any]):
|
def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
||||||
"""Track a telemetry event"""
|
"""Track a telemetry event"""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if properties is None:
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
||||||
try:
|
try:
|
||||||
# Get system info first - most likely source of errors
|
LOG.warning(f"*** Sending telemetry for {event_type} ***")
|
||||||
system_info = self._get_system_info()
|
|
||||||
|
|
||||||
# Send event via PostHog
|
# Send event via PostHog
|
||||||
try:
|
posthog.capture(
|
||||||
posthog.capture(
|
distinct_id=self.run_id,
|
||||||
distinct_id=self.config.distinct_id,
|
event=event_type,
|
||||||
event=event_type,
|
properties={
|
||||||
properties={
|
"system_info": self.system_info,
|
||||||
"run_id": self.run_id,
|
**properties,
|
||||||
"system_info": system_info,
|
},
|
||||||
**properties,
|
)
|
||||||
},
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
)
|
LOG.warning(f"Failed to send telemetry event: {e}")
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to send telemetry event: {e}")
|
|
||||||
except (RuntimeError, OSError) as e:
|
|
||||||
logger.warning(f"Failed to collect system info for telemetry: {e}")
|
|
||||||
except TypeError as e:
|
|
||||||
logger.warning(f"Invalid property type in telemetry event: {e}")
|
|
||||||
except AttributeError as e:
|
|
||||||
logger.warning(f"Failed to access system attribute for telemetry: {e}")
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def track_training(self, config: dict[str, Any]):
|
|
||||||
"""Context manager to track training run"""
|
|
||||||
if not self.enabled:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
# Track training start
|
|
||||||
sanitized_config = {
|
|
||||||
k: v
|
|
||||||
for k, v in config.items()
|
|
||||||
if not any(p in k.lower() for p in ["path", "dir", "file"])
|
|
||||||
}
|
|
||||||
|
|
||||||
self.track_event("training_start", {"config": sanitized_config})
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
# Track successful completion
|
|
||||||
self.track_event("training_complete", {})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Track error
|
|
||||||
self.track_event("training_error", {"error": self._sanitize_error(str(e))})
|
|
||||||
raise
|
|
||||||
|
|
||||||
def track_model_load(self, model_config: ModelConfig):
|
|
||||||
"""Track model loading and configuration"""
|
|
||||||
if not self.enabled or not self._is_whitelisted(model_config.base_model):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.track_event(
|
|
||||||
"model_load",
|
|
||||||
{
|
|
||||||
"model_config": model_config.to_dict(),
|
|
||||||
"system_info": self._get_system_info(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def track_training_metrics(self, metrics: dict):
|
|
||||||
"""Track training progress metrics"""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.track_event(
|
|
||||||
"training_metrics",
|
|
||||||
{
|
|
||||||
"duration": metrics.get("duration"),
|
|
||||||
"peak_memory": metrics.get("peak_memory"),
|
|
||||||
"steps_completed": metrics.get("steps_completed"),
|
|
||||||
"current_loss": metrics.get("loss"),
|
|
||||||
"learning_rate": metrics.get("learning_rate"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Ensure all queued events are processed before shutdown"""
|
"""Ensure all queued events are processed before shutdown"""
|
||||||
@@ -263,6 +232,49 @@ class TelemetryManager:
|
|||||||
posthog.flush()
|
posthog.flush()
|
||||||
|
|
||||||
|
|
||||||
def init_telemetry_manager() -> TelemetryManager:
|
ERROR_HANDLED = False
|
||||||
"""Initialize telemetry system"""
|
|
||||||
return TelemetryManager(TelemetryConfig())
|
|
||||||
|
def track_errors(func: Callable) -> Callable:
|
||||||
|
"""Decorator to track errors in a function"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
telemetry_manager = TelemetryManager.get_instance()
|
||||||
|
if not telemetry_manager.enabled:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as exception:
|
||||||
|
# Only track if we're not already handling an error. This prevents us from
|
||||||
|
# capturing an error more than once in nested decorated function calls.
|
||||||
|
global ERROR_HANDLED # pylint: disable=global-statement
|
||||||
|
if not ERROR_HANDLED:
|
||||||
|
ERROR_HANDLED = True
|
||||||
|
|
||||||
|
# Get function module path
|
||||||
|
module = getmodule(func)
|
||||||
|
module_path = (
|
||||||
|
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get stack trace
|
||||||
|
stack_trace = "".join(
|
||||||
|
traceback.format_exception(
|
||||||
|
type(exception), exception, exception.__traceback__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send error telemetry
|
||||||
|
telemetry_manager.track_event(
|
||||||
|
event_type=f"{module_path}-error",
|
||||||
|
properties={
|
||||||
|
"exception": str(exception),
|
||||||
|
"stack_trace": stack_trace,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ organizations:
|
|||||||
- "huggingface"
|
- "huggingface"
|
||||||
- "nvidia"
|
- "nvidia"
|
||||||
- "facebook"
|
- "facebook"
|
||||||
|
- "mistralai"
|
||||||
- "briaai"
|
- "briaai"
|
||||||
- "unsloth"
|
- "unsloth"
|
||||||
- "NousResearch"
|
- "NousResearch"
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-modu
|
|||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.telemetry import TelemetryManager
|
||||||
|
from axolotl.telemetry.manager import track_errors
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
@@ -39,13 +41,16 @@ sys.path.insert(0, src_dir)
|
|||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def train(
|
def train(
|
||||||
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
@@ -75,7 +80,7 @@ def train(
|
|||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load model
|
||||||
msg = "loading model"
|
msg = "loading model"
|
||||||
if cfg.adapter:
|
if cfg.adapter:
|
||||||
msg += " and peft_config..."
|
msg += " and peft_config..."
|
||||||
@@ -84,6 +89,14 @@ def train(
|
|||||||
if model.generation_config is not None:
|
if model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(
|
||||||
|
event_type="model-load", properties=model.config.to_dict()
|
||||||
|
)
|
||||||
|
if peft_config:
|
||||||
|
TELEMETRY_MANAGER.track_event(
|
||||||
|
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||||
|
)
|
||||||
|
|
||||||
model_ref = None
|
model_ref = None
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
if cfg.rl and cfg.rl != "orpo":
|
||||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
@@ -91,7 +104,7 @@ def train(
|
|||||||
LOG.debug("Passing model_ref: None to RL trainer")
|
LOG.debug("Passing model_ref: None to RL trainer")
|
||||||
model_ref = None # explicit setting to None
|
model_ref = None # explicit setting to None
|
||||||
else:
|
else:
|
||||||
# load the model again for model_ref/baseline
|
# load the model again for model_ref / baseline
|
||||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
@@ -174,6 +187,8 @@ def train(
|
|||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="train-start")
|
||||||
|
|
||||||
pretrain_hooks(cfg, trainer)
|
pretrain_hooks(cfg, trainer)
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
@@ -189,6 +204,8 @@ def train(
|
|||||||
|
|
||||||
post_train_hooks(cfg, trainer)
|
post_train_hooks(cfg, trainer)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.track_event(event_type="train-end")
|
||||||
|
|
||||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
# post training
|
# post training
|
||||||
|
|||||||
@@ -1683,7 +1683,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""Wrapper to validate GPU capabilities with the config options"""
|
||||||
|
|
||||||
capabilities: GPUCapabilities
|
capabilities: GPUCapabilities
|
||||||
env_capabilities: EnvCapabilities
|
env_capabilities: EnvCapabilities
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ from axolotl.monkeypatch.multipack import (
|
|||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
)
|
)
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
|
from axolotl.telemetry.manager import track_errors
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -165,6 +166,7 @@ def load_model_config(cfg):
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
@@ -318,6 +320,7 @@ def load_tokenizer(cfg):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||||
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
|
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
|
||||||
|
|
||||||
@@ -1192,18 +1195,17 @@ class ModelLoader:
|
|||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
*,
|
*,
|
||||||
processor: ProcessorMixin = None, # pylint: disable=unused-argument
|
processor: ProcessorMixin = None,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs,
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, PeftConfig | None]:
|
||||||
"""
|
"""Load a model for a given configuration and tokenizer"""
|
||||||
Load a model for a given configuration and tokenizer.
|
|
||||||
"""
|
|
||||||
loader = ModelLoader(
|
loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -1215,6 +1217,7 @@ def load_model(
|
|||||||
return loader.load_model()
|
return loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
|
@track_errors
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
|
|||||||
@@ -1,71 +1,65 @@
|
|||||||
|
"""Tests for TelemetryManager class and utilities"""
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
|
from axolotl.telemetry import TelemetryManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_whitelist(tmp_path):
|
def mock_whitelist(tmp_path):
|
||||||
"""Create a temporary whitelist file for testing"""
|
"""Create a temporary whitelist file for testing"""
|
||||||
whitelist_content = {
|
whitelist_content = {
|
||||||
"organizations": ["meta", "mistral"],
|
"organizations": ["meta-llama", "mistralai"],
|
||||||
"models": ["llama", "mistral-7b"],
|
|
||||||
}
|
}
|
||||||
whitelist_file = tmp_path / "whitelist.yaml"
|
whitelist_file = tmp_path / "whitelist.yaml"
|
||||||
with open(whitelist_file, "w") as f:
|
with open(whitelist_file, "w", encoding="utf-8") as f:
|
||||||
yaml.dump(whitelist_content, f)
|
yaml.dump(whitelist_content, f)
|
||||||
return str(whitelist_file)
|
return str(whitelist_file)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def config(mock_whitelist):
|
def manager():
|
||||||
"""Create a TelemetryConfig with test settings"""
|
|
||||||
return TelemetryConfig(
|
|
||||||
host="https://test.posthog.com", whitelist_path=mock_whitelist
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def manager(config):
|
|
||||||
"""Create a TelemetryManager instance with mocked PostHog"""
|
"""Create a TelemetryManager instance with mocked PostHog"""
|
||||||
with patch("posthog.capture"):
|
with patch("posthog.capture"):
|
||||||
return TelemetryManager(config)
|
return TelemetryManager()
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_by_default():
|
def test_telemetry_disabled_by_default():
|
||||||
"""Test that telemetry is disabled by default"""
|
"""Test that telemetry is disabled by default"""
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_opt_in():
|
def test_telemetry_opt_in():
|
||||||
"""Test that telemetry can be enabled via environment variable"""
|
"""Test that telemetry can be enabled via environment variable"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager()
|
||||||
assert manager.enabled
|
assert manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_do_not_track_override():
|
def test_do_not_track_override():
|
||||||
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
|
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
|
||||||
manager = TelemetryManager(TelemetryConfig())
|
manager = TelemetryManager()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
def test_whitelist_checking(manager):
|
def test_whitelist_checking(manager):
|
||||||
"""Test model whitelist functionality"""
|
"""Test model whitelist functionality"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
# Should match organization
|
||||||
# Should match organization
|
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||||
assert manager._is_whitelisted("meta/llama-7b")
|
# Should match model name
|
||||||
# Should match model name
|
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||||
assert manager._is_whitelisted("mistral-7b-instruct")
|
# Should not match
|
||||||
# Should not match
|
assert not manager._is_whitelisted("unknown/model")
|
||||||
assert not manager._is_whitelisted("unknown/model")
|
# Should handle case insensitively
|
||||||
# Should handle case insensitively
|
assert manager._is_whitelisted("meta/Llama-7b")
|
||||||
assert manager._is_whitelisted("meta/Llama-7b")
|
|
||||||
|
|
||||||
|
|
||||||
def test_event_tracking(manager):
|
def test_event_tracking(manager):
|
||||||
@@ -81,33 +75,6 @@ def test_event_tracking(manager):
|
|||||||
assert "system_info" in mock_capture.call_args[1]["properties"]
|
assert "system_info" in mock_capture.call_args[1]["properties"]
|
||||||
|
|
||||||
|
|
||||||
def test_model_tracking(manager):
|
|
||||||
"""Test model load tracking"""
|
|
||||||
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
|
|
||||||
model_config = ModelConfig(
|
|
||||||
base_model="meta/llama-7b",
|
|
||||||
model_type="decoder",
|
|
||||||
hidden_size=4096,
|
|
||||||
num_layers=32,
|
|
||||||
num_attention_heads=32,
|
|
||||||
tokenizer_config={},
|
|
||||||
flash_attention=True,
|
|
||||||
quantization_config=None,
|
|
||||||
training_approach="lora",
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
|
||||||
manager.enabled = True
|
|
||||||
manager.track_model_load(model_config)
|
|
||||||
|
|
||||||
assert mock_capture.called
|
|
||||||
assert mock_capture.call_args[1]["event"] == "model_load"
|
|
||||||
assert (
|
|
||||||
mock_capture.call_args[1]["properties"]["model_config"]
|
|
||||||
== model_config.to_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_training_context(manager):
|
def test_training_context(manager):
|
||||||
"""Test training context manager"""
|
"""Test training context manager"""
|
||||||
config = {"model": "llama", "batch_size": 8}
|
config = {"model": "llama", "batch_size": 8}
|
||||||
@@ -141,6 +108,7 @@ def test_training_error(manager):
|
|||||||
assert "training_error" in events
|
assert "training_error" in events
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
def test_path_sanitization(manager):
|
def test_path_sanitization(manager):
|
||||||
"""Test path sanitization"""
|
"""Test path sanitization"""
|
||||||
path = "/home/user/sensitive/data.txt"
|
path = "/home/user/sensitive/data.txt"
|
||||||
@@ -149,6 +117,7 @@ def test_path_sanitization(manager):
|
|||||||
assert "/home/user" not in sanitized
|
assert "/home/user" not in sanitized
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
def test_error_sanitization(manager):
|
def test_error_sanitization(manager):
|
||||||
"""Test error message sanitization"""
|
"""Test error message sanitization"""
|
||||||
error = "Failed to load /home/user/sensitive/data.txt: File not found"
|
error = "Failed to load /home/user/sensitive/data.txt: File not found"
|
||||||
|
|||||||
Reference in New Issue
Block a user