update error file path sanitization function; adding more error tracking
This commit is contained in:
@@ -3,4 +3,44 @@ title: Telemetry
|
|||||||
description: A description of the opt-out telemetry implementation in Axolotl.
|
description: A description of the opt-out telemetry implementation in Axolotl.
|
||||||
---
|
---
|
||||||
|
|
||||||
TODO.
|
# Telemetry in Axolotl
|
||||||
|
|
||||||
|
Axolotl implements anonymous telemetry to help maintainers understand how the library
|
||||||
|
is used and where users encounter issues. This data helps prioritize features, optimize
|
||||||
|
performance, and fix bugs.
|
||||||
|
|
||||||
|
## Data Collection
|
||||||
|
|
||||||
|
We collect:
|
||||||
|
|
||||||
|
- **System info**: OS, Python version, PyTorch version, Transformers version, Axolotl version
|
||||||
|
- **Hardware info**: CPU count, memory, GPU count and models
|
||||||
|
- **Usage patterns**: Models (from a whitelist) and configurations used
|
||||||
|
- **Error tracking**: Stack traces and error messages (sanitized to remove personal information)
|
||||||
|
|
||||||
|
No personally identifiable information (PII) is collected.
|
||||||
|
|
||||||
|
## Implementation
|
||||||
|
|
||||||
|
Telemetry is implemented using PostHog and consists of:
|
||||||
|
|
||||||
|
1. `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the telemetry system and provides methods for tracking events.
|
||||||
|
2. `axolotl.telemetry.errors.track_errors`: A decorator that captures exceptions and sends sanitized stack traces.
|
||||||
|
|
||||||
|
## Opt-Out Mechanism
|
||||||
|
|
||||||
|
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
|
||||||
|
|
||||||
|
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
|
||||||
|
- `DO_NOT_TRACK=1` (Global standard)
|
||||||
|
|
||||||
|
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
|
||||||
|
`AXOLOTL_DO_NOT_TRACK=0`
|
||||||
|
|
||||||
|
## Privacy
|
||||||
|
|
||||||
|
- Stack traces are sanitized to remove personal file paths, keeping only the Axolotl code paths
|
||||||
|
- Each run generates a unique anonymous ID
|
||||||
|
- Only whitelisted organization information is tracked
|
||||||
|
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
|
||||||
|
- Telemetry is only sent from the main process to avoid duplicate events
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.telemetry import TelemetryManager
|
from axolotl.telemetry import TelemetryManager
|
||||||
from axolotl.telemetry.manager import track_errors
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
@@ -163,7 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
|
|||||||
plugin_manager.cfg = cfg
|
plugin_manager.cfg = cfg
|
||||||
|
|
||||||
|
|
||||||
@track_errors
|
@send_errors
|
||||||
def load_cfg(
|
def load_cfg(
|
||||||
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||||
) -> DictDefault:
|
) -> DictDefault:
|
||||||
@@ -197,7 +197,7 @@ def load_cfg(
|
|||||||
temp_file.close()
|
temp_file.close()
|
||||||
cfg.axolotl_config_path = temp_file.name
|
cfg.axolotl_config_path = temp_file.name
|
||||||
|
|
||||||
TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
|
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
|
||||||
|
|
||||||
# If there are any options passed in the cli, if it is something that seems valid
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
# from the yaml, then overwrite the value
|
# from the yaml, then overwrite the value
|
||||||
@@ -240,6 +240,6 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
|
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from axolotl.cli.args import InferenceCliArgs
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
get_chat_template,
|
get_chat_template,
|
||||||
get_chat_template_from_config,
|
get_chat_template_from_config,
|
||||||
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
|
|||||||
return instruction
|
return instruction
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def do_inference(
|
def do_inference(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -135,6 +137,7 @@ def do_inference(
|
|||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def do_inference_gradio(
|
def do_inference_gradio(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -9,12 +9,14 @@ from dotenv import load_dotenv
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||||
"""
|
"""
|
||||||
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
|||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights(
|
|||||||
return save_path_
|
return save_path_
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def merge_fsdp_weights(
|
def merge_fsdp_weights(
|
||||||
checkpoint_dir: str,
|
checkpoint_dir: str,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.trainer import disable_datasets_caching
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
@@ -25,6 +26,7 @@ from axolotl.utils.trainer import disable_datasets_caching
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||||
"""
|
"""
|
||||||
Preprocesses dataset specified in axolotl config.
|
Preprocesses dataset specified in axolotl config.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from datasets import Dataset
|
|||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.loaders import load_processor, load_tokenizer
|
from axolotl.loaders import load_processor, load_tokenizer
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -45,6 +46,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -112,6 +114,7 @@ def load_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def load_preference_datasets(
|
def load_preference_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from axolotl.train import (
|
|||||||
TrainDatasetMeta,
|
TrainDatasetMeta,
|
||||||
setup_model_and_tokenizer,
|
setup_model_and_tokenizer,
|
||||||
)
|
)
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
@@ -63,6 +65,7 @@ def evaluate_dataset(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
@send_errors
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets.
|
Evaluate a model on training and validation datasets.
|
||||||
|
|||||||
112
src/axolotl/telemetry/errors.py
Normal file
112
src/axolotl/telemetry/errors.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""Telemetry utilities for exception and traceback information."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import traceback
|
||||||
|
from functools import wraps
|
||||||
|
from inspect import getmodule
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ERROR_HANDLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_stack_trace(stack_trace: str) -> str:
|
||||||
|
"""
|
||||||
|
Remove personal information from stack trace messages while keeping Axolotl codepaths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stack_trace: The original stack trace string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sanitized version of the stack trace with only axolotl paths preserved.
|
||||||
|
"""
|
||||||
|
# Split the stack trace into lines to process each file path separately
|
||||||
|
lines = stack_trace.split("\n")
|
||||||
|
sanitized_lines = []
|
||||||
|
|
||||||
|
# Regular expression to find file paths in the stack trace
|
||||||
|
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
# Check if this line contains a file path
|
||||||
|
path_match = path_pattern.search(line)
|
||||||
|
|
||||||
|
if path_match:
|
||||||
|
full_path = path_match.group(1)
|
||||||
|
|
||||||
|
if "axolotl/" in full_path:
|
||||||
|
# Keep only the 'axolotl' part and onward
|
||||||
|
axolotl_idx = full_path.rfind("axolotl/")
|
||||||
|
if axolotl_idx >= 0:
|
||||||
|
# Replace the original path with the sanitized one
|
||||||
|
sanitized_path = full_path[axolotl_idx:]
|
||||||
|
line = line.replace(full_path, sanitized_path)
|
||||||
|
else:
|
||||||
|
# For non-axolotl paths, replace with an empty string or a placeholder
|
||||||
|
line = line.replace(full_path, "")
|
||||||
|
|
||||||
|
sanitized_lines.append(line)
|
||||||
|
|
||||||
|
return "\n".join(sanitized_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def send_errors(func: Callable) -> Callable:
|
||||||
|
"""
|
||||||
|
Decorator to send exception info in a function. If an exception is raised, we send
|
||||||
|
telemetry containing the stack trace and error message.
|
||||||
|
|
||||||
|
If an error occurs in a decorated function that is called by another decorated
|
||||||
|
function, we'll only send telemetry corresponding to the lower-level function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to decorate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorated function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
telemetry_manager = TelemetryManager.get_instance()
|
||||||
|
if not telemetry_manager.enabled:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as exception:
|
||||||
|
# Only track if we're not already handling an error. This prevents us from
|
||||||
|
# capturing an error more than once in nested decorated function calls.
|
||||||
|
global ERROR_HANDLED # pylint: disable=global-statement
|
||||||
|
if not ERROR_HANDLED:
|
||||||
|
ERROR_HANDLED = True
|
||||||
|
|
||||||
|
# Get function module path
|
||||||
|
module = getmodule(func)
|
||||||
|
module_path = (
|
||||||
|
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get stack trace
|
||||||
|
stack_trace = "".join(
|
||||||
|
traceback.format_exception(
|
||||||
|
type(exception), exception, exception.__traceback__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stack_trace = sanitize_stack_trace(stack_trace)
|
||||||
|
|
||||||
|
# Send error telemetry
|
||||||
|
telemetry_manager.send_event(
|
||||||
|
event_type=f"{module_path}-error",
|
||||||
|
properties={
|
||||||
|
"exception": str(exception),
|
||||||
|
"stack_trace": stack_trace,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -5,13 +5,10 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import wraps
|
|
||||||
from inspect import getmodule
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any
|
||||||
|
|
||||||
import posthog
|
import posthog
|
||||||
import psutil
|
import psutil
|
||||||
@@ -24,8 +21,8 @@ from axolotl.utils.distributed import is_main_process
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
|
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
||||||
ENABLED_WARNING_SLEEP_SECONDS = 10
|
ENABLED_WARNING_SLEEP_SECONDS = 15
|
||||||
ENABLED_WARNING = (
|
ENABLED_WARNING = (
|
||||||
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
||||||
"- Which models and configurations are most commonly used\n"
|
"- Which models and configurations are most commonly used\n"
|
||||||
@@ -166,18 +163,6 @@ class TelemetryManager:
|
|||||||
"""Remove personal information from file paths"""
|
"""Remove personal information from file paths"""
|
||||||
return Path(path).name
|
return Path(path).name
|
||||||
|
|
||||||
def _sanitize_error(self, error: str) -> str:
|
|
||||||
"""Remove personal information from error messages"""
|
|
||||||
# Replace file paths with just filename
|
|
||||||
sanitized = error
|
|
||||||
try:
|
|
||||||
for path in Path(error).parents:
|
|
||||||
sanitized = sanitized.replace(str(path), "")
|
|
||||||
except (ValueError, RuntimeError) as e:
|
|
||||||
LOG.debug(f"Could not parse path in error message: {e}")
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
def _get_system_info(self) -> dict[str, Any]:
|
def _get_system_info(self) -> dict[str, Any]:
|
||||||
"""Collect system information"""
|
"""Collect system information"""
|
||||||
gpu_info = []
|
gpu_info = []
|
||||||
@@ -202,8 +187,8 @@ class TelemetryManager:
|
|||||||
"gpu_info": gpu_info,
|
"gpu_info": gpu_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
||||||
"""Track a telemetry event"""
|
"""Send a telemetry event"""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -218,63 +203,16 @@ class TelemetryManager:
|
|||||||
posthog.capture(
|
posthog.capture(
|
||||||
distinct_id=self.run_id,
|
distinct_id=self.run_id,
|
||||||
event=event_type,
|
event=event_type,
|
||||||
properties={
|
properties=properties,
|
||||||
"system_info": self.system_info,
|
|
||||||
**properties,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
LOG.warning(f"Failed to send telemetry event: {e}")
|
LOG.warning(f"Failed to send telemetry event: {e}")
|
||||||
|
|
||||||
|
def send_system_info(self):
|
||||||
|
"""Helper method for sending system info"""
|
||||||
|
self.send_event(event_type="system-info", properties=self.system_info)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Ensure all queued events are processed before shutdown"""
|
"""Ensure all queued events are processed before shutdown"""
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
posthog.flush()
|
posthog.flush()
|
||||||
|
|
||||||
|
|
||||||
ERROR_HANDLED = False
|
|
||||||
|
|
||||||
|
|
||||||
def track_errors(func: Callable) -> Callable:
|
|
||||||
"""Decorator to track errors in a function"""
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs) -> Any:
|
|
||||||
telemetry_manager = TelemetryManager.get_instance()
|
|
||||||
if not telemetry_manager.enabled:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except Exception as exception:
|
|
||||||
# Only track if we're not already handling an error. This prevents us from
|
|
||||||
# capturing an error more than once in nested decorated function calls.
|
|
||||||
global ERROR_HANDLED # pylint: disable=global-statement
|
|
||||||
if not ERROR_HANDLED:
|
|
||||||
ERROR_HANDLED = True
|
|
||||||
|
|
||||||
# Get function module path
|
|
||||||
module = getmodule(func)
|
|
||||||
module_path = (
|
|
||||||
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get stack trace
|
|
||||||
stack_trace = "".join(
|
|
||||||
traceback.format_exception(
|
|
||||||
type(exception), exception, exception.__traceback__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send error telemetry
|
|
||||||
telemetry_manager.track_event(
|
|
||||||
event_type=f"{module_path}-error",
|
|
||||||
properties={
|
|
||||||
"exception": str(exception),
|
|
||||||
"stack_trace": stack_trace,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
raise
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from contextlib import ExitStack
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from axolotl.telemetry.manager import track_errors
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
@@ -33,6 +32,8 @@ from axolotl.loaders import (
|
|||||||
load_processor,
|
load_processor,
|
||||||
load_tokenizer,
|
load_tokenizer,
|
||||||
)
|
)
|
||||||
|
from axolotl.telemetry import TelemetryManager
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
@@ -540,7 +541,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@track_errors
|
@send_errors
|
||||||
def train(
|
def train(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||||
@@ -565,9 +566,23 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
|
TELEMETRY_MANAGER.send_event(
|
||||||
|
event_type="model-load", properties=model.config.to_dict()
|
||||||
|
)
|
||||||
|
if peft_config:
|
||||||
|
TELEMETRY_MANAGER.send_event(
|
||||||
|
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||||
|
)
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.post_trainer_create(cfg, trainer)
|
plugin_manager.post_trainer_create(cfg, trainer)
|
||||||
|
|
||||||
|
# Determine if we need to resume from a checkpoint
|
||||||
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||||
|
|
||||||
|
# Configuration for saving
|
||||||
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
# Handle untrained tokens if configured
|
# Handle untrained tokens if configured
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
@@ -579,12 +594,11 @@ def train(
|
|||||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||||
setup_signal_handler(cfg, model, safe_serialization)
|
setup_signal_handler(cfg, model, safe_serialization)
|
||||||
setup_model_card(cfg)
|
setup_model_card(cfg)
|
||||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
|
||||||
|
|
||||||
# Execute the training
|
# Execute the training
|
||||||
TELEMETRY_MANAGER.track_event(event_type="train-start")
|
TELEMETRY_MANAGER.send_event(event_type="train-start")
|
||||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
TELEMETRY_MANAGER.track_event(event_type="train-end")
|
TELEMETRY_MANAGER.send_event(event_type="train-end")
|
||||||
|
|
||||||
# Save the trained model and cleanup
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
|
|||||||
1415
src/axolotl/utils/models.py
Normal file
1415
src/axolotl/utils/models.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user