adding runtime metrics / system info additional accelerator support, etc.
This commit is contained in:
@@ -31,8 +31,8 @@ Telemetry is implemented using PostHog and consists of:
|
|||||||
telemetry system and provides methods for tracking events.
|
telemetry system and provides methods for tracking events.
|
||||||
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
|
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
|
||||||
sends sanitized stack traces.
|
sends sanitized stack traces.
|
||||||
- `axolotl.telemetry.runtime_metrics.RuntimeMetrics`: A dataclass that tracks runtime
|
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
|
||||||
metrics during training.
|
runtime metrics during training.
|
||||||
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
|
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
|
||||||
runtime metrics telemetry.
|
runtime metrics telemetry.
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ aware of data collection, unless telemetry is explicitly enabled or disabled.
|
|||||||
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
|
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
|
||||||
|
|
||||||
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
|
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
|
||||||
- `DO_NOT_TRACK=1` (Global standard)
|
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)
|
||||||
|
|
||||||
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
|
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
|
||||||
`AXOLOTL_DO_NOT_TRACK=0`.
|
`AXOLOTL_DO_NOT_TRACK=0`.
|
||||||
|
|||||||
@@ -131,17 +131,20 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
# Update memory metrics
|
# Update memory metrics
|
||||||
self.tracker.update_memory_metrics()
|
self.tracker.update_memory_metrics()
|
||||||
|
|
||||||
|
loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
|
||||||
|
learning_rate = (
|
||||||
|
state.log_history[-1].get("learning_rate", 0)
|
||||||
|
if state.log_history
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare metrics to report
|
# Prepare metrics to report
|
||||||
metrics = {
|
metrics = {
|
||||||
"step": step,
|
"step": step,
|
||||||
"epoch": self.current_epoch,
|
"epoch": self.current_epoch,
|
||||||
"progress": state.epoch, # Fractional epoch progress
|
"progress": state.epoch, # Fractional epoch progress
|
||||||
"loss": state.log_history[-1].get("loss", 0)
|
"loss": loss,
|
||||||
if state.log_history
|
"learning_rate": learning_rate,
|
||||||
else 0,
|
|
||||||
"learning_rate": state.log_history[-1].get("learning_rate", 0)
|
|
||||||
if state.log_history
|
|
||||||
else 0,
|
|
||||||
"steps_per_second": steps_per_second,
|
"steps_per_second": steps_per_second,
|
||||||
"elapsed_time": current_time - self.start_time,
|
"elapsed_time": current_time - self.start_time,
|
||||||
"time_since_last_report": time_since_last_report,
|
"time_since_last_report": time_since_last_report,
|
||||||
@@ -149,7 +152,7 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
|
|
||||||
# Add memory metrics
|
# Add memory metrics
|
||||||
memory_metrics = self.tracker.get_memory_metrics()
|
memory_metrics = self.tracker.get_memory_metrics()
|
||||||
metrics.update(memory_metrics)
|
metrics.update({"memory": memory_metrics})
|
||||||
|
|
||||||
# Send telemetry
|
# Send telemetry
|
||||||
self.telemetry_manager.send_event(
|
self.telemetry_manager.send_event(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Telemetry manager and associated utilities."""
|
"""Telemetry manager and associated utilities."""
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@@ -12,10 +13,8 @@ from typing import Any
|
|||||||
import posthog
|
import posthog
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import axolotl
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -32,7 +31,7 @@ ENABLED_WARNING = (
|
|||||||
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
|
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
|
||||||
"To disable telemetry, set either:\n"
|
"To disable telemetry, set either:\n"
|
||||||
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
|
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
|
||||||
"- DO_NOT_TRACK=1 (Global standard)\n\n"
|
"- DO_NOT_TRACK=1 (Global standard; see https://consoledonottrack.com/)\n\n"
|
||||||
"To remove this warning and continue with telemetry enabled,"
|
"To remove this warning and continue with telemetry enabled,"
|
||||||
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
|
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
|
||||||
"No personally identifiable information is collected."
|
"No personally identifiable information is collected."
|
||||||
@@ -42,13 +41,39 @@ ENABLED_WARNING = (
|
|||||||
|
|
||||||
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
||||||
|
|
||||||
FIELDS_WITH_ORGS = [
|
# NOTE: Keep these up to date with any config schema changes
|
||||||
|
FIELDS_WITH_ORGS = {
|
||||||
"base_model",
|
"base_model",
|
||||||
"tokenizer_config",
|
"tokenizer_config",
|
||||||
"base_model_config",
|
"base_model_config",
|
||||||
]
|
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
|
||||||
FIELDS_TO_REDACT = ["resume_from_checkpoint", "hub_model_id"]
|
}
|
||||||
PREFIXES_TO_REDACT = ["wandb_", "comet_", "mlflow_", "gradio_"]
|
FIELDS_TO_REDACT = {"resume_from_checkpoint", "hub_model_id"}
|
||||||
|
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
|
||||||
|
PATH_INDICATORS = {"path", "dir"}
|
||||||
|
|
||||||
|
RELEVANT_PACKAGES = {
|
||||||
|
"torch",
|
||||||
|
"transformers",
|
||||||
|
"trl",
|
||||||
|
"datasets",
|
||||||
|
"peft",
|
||||||
|
"bitsandbytes",
|
||||||
|
"accelerate",
|
||||||
|
"optimum",
|
||||||
|
"deepspeed",
|
||||||
|
"ray",
|
||||||
|
"axolotl",
|
||||||
|
"triton",
|
||||||
|
"mamba-ssm",
|
||||||
|
"flash-attn",
|
||||||
|
"xformers",
|
||||||
|
"autoawq",
|
||||||
|
"tokenizers",
|
||||||
|
"sentencepiece",
|
||||||
|
"torchao",
|
||||||
|
"lm_eval",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
class TelemetryManager:
|
||||||
@@ -78,7 +103,13 @@ class TelemetryManager:
|
|||||||
if self.enabled:
|
if self.enabled:
|
||||||
self.run_id = str(uuid.uuid4())
|
self.run_id = str(uuid.uuid4())
|
||||||
self.whitelist = self._load_whitelist()
|
self.whitelist = self._load_whitelist()
|
||||||
self.system_info = self._get_system_info()
|
|
||||||
|
try:
|
||||||
|
self.system_info = self._get_system_info()
|
||||||
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
|
LOG.warning(f"Error during system info collection: {e}")
|
||||||
|
self.system_info = None
|
||||||
|
|
||||||
self._init_posthog()
|
self._init_posthog()
|
||||||
|
|
||||||
# Register shutdown method to flush posthog telemetry
|
# Register shutdown method to flush posthog telemetry
|
||||||
@@ -174,9 +205,6 @@ class TelemetryManager:
|
|||||||
if not properties:
|
if not properties:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# NOTE: Keep this up to date with any config schema changes
|
|
||||||
path_indicators = {"path", "dir"}
|
|
||||||
|
|
||||||
def redact_value(value: Any, key: str = "") -> Any:
|
def redact_value(value: Any, key: str = "") -> Any:
|
||||||
"""Recursively sanitize values, redacting those with path-like keys"""
|
"""Recursively sanitize values, redacting those with path-like keys"""
|
||||||
if isinstance(key, str) and isinstance(value, str):
|
if isinstance(key, str) and isinstance(value, str):
|
||||||
@@ -190,7 +218,7 @@ class TelemetryManager:
|
|||||||
if (
|
if (
|
||||||
key in FIELDS_TO_REDACT
|
key in FIELDS_TO_REDACT
|
||||||
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
|
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
|
||||||
or any(indicator in key.lower() for indicator in path_indicators)
|
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
|
||||||
):
|
):
|
||||||
return "[REDACTED]"
|
return "[REDACTED]"
|
||||||
|
|
||||||
@@ -208,27 +236,100 @@ class TelemetryManager:
|
|||||||
return redacted
|
return redacted
|
||||||
|
|
||||||
def _get_system_info(self) -> dict[str, Any]:
|
def _get_system_info(self) -> dict[str, Any]:
|
||||||
"""Collect system information"""
|
"""Collect system information for various hardware accelerators"""
|
||||||
gpu_info = []
|
gpu_info = []
|
||||||
|
accelerator_type = "none"
|
||||||
|
|
||||||
|
# NVIDIA GPUs
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
accelerator_type = "cuda"
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
gpu_info.append(
|
gpu_info.append(
|
||||||
{
|
{
|
||||||
"name": torch.cuda.get_device_name(i),
|
"name": torch.cuda.get_device_name(i),
|
||||||
"memory": torch.cuda.get_device_properties(i).total_memory,
|
"memory": torch.cuda.get_device_properties(i).total_memory,
|
||||||
|
"type": "cuda",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# AMD GPUs
|
||||||
|
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||||
|
accelerator_type = "hip"
|
||||||
|
for i in range(torch.hip.device_count()):
|
||||||
|
gpu_info.append(
|
||||||
|
{
|
||||||
|
"name": torch.hip.get_device_name(i),
|
||||||
|
"memory": torch.hip.get_device_properties(i).total_memory
|
||||||
|
if hasattr(torch.hip, "get_device_properties")
|
||||||
|
else None,
|
||||||
|
"type": "hip",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apple Silicon
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
accelerator_type = "mps"
|
||||||
|
gpu_info.append(
|
||||||
|
{
|
||||||
|
"name": "Apple Silicon",
|
||||||
|
# NOTE: this is memory allocated to this process, not total memory
|
||||||
|
"memory": torch.mps.driver_allocated_memory(),
|
||||||
|
"type": "mps",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Intel GPUs
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
accelerator_type = "xpu"
|
||||||
|
for i in range(torch.xpu.device_count()):
|
||||||
|
memory = None
|
||||||
|
if hasattr(torch.xpu, "get_device_properties"):
|
||||||
|
memory = torch.xpu.get_device_properties(i).total_memory
|
||||||
|
|
||||||
|
gpu_info.append(
|
||||||
|
{
|
||||||
|
"name": torch.xpu.get_device_name(i),
|
||||||
|
"memory": memory,
|
||||||
|
"type": "xpu",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# NPUs
|
||||||
|
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||||
|
accelerator_type = "npu"
|
||||||
|
for i in range(torch.npu.device_count()):
|
||||||
|
name = getattr(torch.npu, "get_device_name", lambda x: "NPU")(i)
|
||||||
|
|
||||||
|
memory = None
|
||||||
|
if hasattr(torch.npu, "get_device_properties"):
|
||||||
|
memory = torch.npu.get_device_properties(i).total_memory
|
||||||
|
|
||||||
|
gpu_info.append(
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"memory": memory,
|
||||||
|
"type": "npu",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get relevant package versions
|
||||||
|
installed_packages = {}
|
||||||
|
for package in RELEVANT_PACKAGES:
|
||||||
|
try:
|
||||||
|
version = importlib.metadata.version(package)
|
||||||
|
installed_packages[f"{package}_version"] = version
|
||||||
|
except importlib.metadata.PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"os": platform.system(),
|
"os": platform.system(),
|
||||||
"python_version": platform.python_version(),
|
"python_version": platform.python_version(),
|
||||||
"pytorch_version": torch.__version__,
|
|
||||||
"transformers_version": transformers.__version__,
|
|
||||||
"axolotl_version": axolotl.__version__,
|
|
||||||
"cpu_count": psutil.cpu_count(),
|
"cpu_count": psutil.cpu_count(),
|
||||||
"memory_total": psutil.virtual_memory().total,
|
"memory_total": psutil.virtual_memory().total,
|
||||||
"gpu_count": len(gpu_info),
|
"accelerator_type": accelerator_type,
|
||||||
"gpu_info": gpu_info,
|
"accelerator_count": len(gpu_info),
|
||||||
|
"accelerator_info": gpu_info,
|
||||||
|
**installed_packages,
|
||||||
}
|
}
|
||||||
|
|
||||||
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
||||||
|
|||||||
@@ -112,6 +112,8 @@ class RuntimeMetrics:
|
|||||||
class RuntimeMetricsTracker:
|
class RuntimeMetricsTracker:
|
||||||
"""Tracker for runtime metrics during training."""
|
"""Tracker for runtime metrics during training."""
|
||||||
|
|
||||||
|
update_interval = 100
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the runtime metrics tracker."""
|
"""Initialize the runtime metrics tracker."""
|
||||||
self.metrics = RuntimeMetrics(start_time=time.time())
|
self.metrics = RuntimeMetrics(start_time=time.time())
|
||||||
@@ -132,23 +134,62 @@ class RuntimeMetricsTracker:
|
|||||||
self.metrics.current_step = step
|
self.metrics.current_step = step
|
||||||
self.metrics.total_steps += 1
|
self.metrics.total_steps += 1
|
||||||
|
|
||||||
# Periodically update memory metrics (e.g., every 100 steps)
|
# Periodically update memory metrics
|
||||||
if step % 100 == 0:
|
if step % self.update_interval == 0:
|
||||||
self.update_memory_metrics()
|
self.update_memory_metrics()
|
||||||
|
|
||||||
|
def _get_allocated_memory(self) -> dict[int, int]:
|
||||||
|
"""
|
||||||
|
Helper function for getting accelerator-agnostic allocated memory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping device IDs to allocated memory in bytes
|
||||||
|
"""
|
||||||
|
memory_used: dict[int, int] = {}
|
||||||
|
|
||||||
|
# NVIDIA GPUs
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
memory_used[i] = torch.cuda.memory_allocated(i)
|
||||||
|
|
||||||
|
# AMD GPUs
|
||||||
|
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||||
|
for i in range(torch.hip.device_count()):
|
||||||
|
if hasattr(torch.hip, "memory_allocated"):
|
||||||
|
memory_used[i] = torch.hip.memory_allocated(i)
|
||||||
|
|
||||||
|
# Apple Silicon
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
# MPS doesn't have per-device memory stats since there's only one device
|
||||||
|
if hasattr(torch.mps, "current_allocated_memory"):
|
||||||
|
memory_used[0] = torch.mps.current_allocated_memory()
|
||||||
|
|
||||||
|
# Intel GPUs
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
for i in range(torch.xpu.device_count()):
|
||||||
|
if hasattr(torch.xpu, "memory_allocated"):
|
||||||
|
memory_used[i] = torch.xpu.memory_allocated(i)
|
||||||
|
|
||||||
|
# NPUs
|
||||||
|
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||||
|
for i in range(torch.npu.device_count()):
|
||||||
|
if hasattr(torch.npu, "memory_allocated"):
|
||||||
|
memory_used[i] = torch.npu.memory_allocated(i)
|
||||||
|
|
||||||
|
return memory_used
|
||||||
|
|
||||||
def update_memory_metrics(self):
|
def update_memory_metrics(self):
|
||||||
"""Update peak memory usage metrics."""
|
"""Update peak memory usage metrics."""
|
||||||
# CPU memory
|
# CPU memory
|
||||||
cpu_memory = psutil.Process().memory_info().rss
|
cpu_memory = psutil.Process().memory_info().rss
|
||||||
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
|
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
|
||||||
|
|
||||||
# GPU memory if available
|
# GPU memory (if available)
|
||||||
if torch.cuda.is_available():
|
memory_used = self._get_allocated_memory()
|
||||||
for i in range(torch.cuda.device_count()):
|
for i, memory in memory_used.items():
|
||||||
memory_used = torch.cuda.memory_allocated(i)
|
self.metrics.peak_gpu_memory[i] = max(
|
||||||
self.metrics.peak_gpu_memory[i] = max(
|
self.metrics.peak_gpu_memory.get(i, 0), memory
|
||||||
self.metrics.peak_gpu_memory.get(i, 0), memory_used
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def get_memory_metrics(self) -> dict[str, Any]:
|
def get_memory_metrics(self) -> dict[str, Any]:
|
||||||
"""Get the current memory metrics as a dictionary."""
|
"""Get the current memory metrics as a dictionary."""
|
||||||
@@ -157,11 +198,12 @@ class RuntimeMetricsTracker:
|
|||||||
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
|
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
|
||||||
}
|
}
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
# GPU memory (if available)
|
||||||
for i in range(torch.cuda.device_count()):
|
memory_used = self._get_allocated_memory()
|
||||||
memory_metrics[f"gpu_{i}_memory_bytes"] = torch.cuda.memory_allocated(i)
|
for i, memory in memory_used.items():
|
||||||
memory_metrics[
|
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
||||||
f"gpu_{i}_peak_memory_bytes"
|
memory_metrics[
|
||||||
] = self.metrics.peak_gpu_memory.get(i, 0)
|
f"gpu_{i}_peak_memory_bytes"
|
||||||
|
] = self.metrics.peak_gpu_memory.get(i, 0)
|
||||||
|
|
||||||
return {"memory": memory_metrics}
|
return memory_metrics
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
organizations:
|
organizations:
|
||||||
|
- "axolotl-ai-co"
|
||||||
- "meta-llama"
|
- "meta-llama"
|
||||||
- "huggingface"
|
- "huggingface"
|
||||||
- "nvidia"
|
- "nvidia"
|
||||||
- "facebook"
|
- "facebook"
|
||||||
|
- "google"
|
||||||
|
- "microsoft"
|
||||||
|
- "deepseek-ai"
|
||||||
|
- "HuggingFaceTB"
|
||||||
- "mistralai"
|
- "mistralai"
|
||||||
|
- "Qwen"
|
||||||
- "briaai"
|
- "briaai"
|
||||||
- "unsloth"
|
- "unsloth"
|
||||||
- "NousResearch"
|
- "NousResearch"
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def setup_model_and_tokenizer(
|
|||||||
"""
|
"""
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_cocnfig or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
@@ -597,9 +597,7 @@ def train(
|
|||||||
setup_model_card(cfg)
|
setup_model_card(cfg)
|
||||||
|
|
||||||
# Execute the training
|
# Execute the training
|
||||||
TELEMETRY_MANAGER.send_event(event_type="train-start")
|
|
||||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
TELEMETRY_MANAGER.send_event(event_type="train-end")
|
|
||||||
|
|
||||||
# Save the trained model and cleanup
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
|
|||||||
@@ -151,12 +151,12 @@ def test_system_info_collection(manager):
|
|||||||
# Check essential keys
|
# Check essential keys
|
||||||
assert "os" in system_info
|
assert "os" in system_info
|
||||||
assert "python_version" in system_info
|
assert "python_version" in system_info
|
||||||
assert "pytorch_version" in system_info
|
assert "torch_version" in system_info
|
||||||
assert "transformers_version" in system_info
|
assert "transformers_version" in system_info
|
||||||
assert "axolotl_version" in system_info
|
assert "axolotl_version" in system_info
|
||||||
assert "cpu_count" in system_info
|
assert "cpu_count" in system_info
|
||||||
assert "memory_total" in system_info
|
assert "memory_total" in system_info
|
||||||
assert "gpu_count" in system_info
|
assert "accelerator_count" in system_info
|
||||||
|
|
||||||
|
|
||||||
def test_send_event(manager):
|
def test_send_event(manager):
|
||||||
|
|||||||
@@ -331,30 +331,26 @@ class TestRuntimeMetricsTracker:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Get memory metrics
|
# Get memory metrics
|
||||||
result = tracker.get_memory_metrics()
|
memory_metrics = tracker.get_memory_metrics()
|
||||||
|
|
||||||
# Verify structure
|
|
||||||
assert "memory" in result
|
|
||||||
memory = result["memory"]
|
|
||||||
|
|
||||||
# Verify CPU memory
|
# Verify CPU memory
|
||||||
assert (
|
assert (
|
||||||
memory["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||||
) # Current value from mock
|
) # Current value from mock
|
||||||
assert (
|
assert (
|
||||||
memory["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||||
) # Peak value we set
|
) # Peak value we set
|
||||||
|
|
||||||
# Verify GPU memory
|
# Verify GPU memory
|
||||||
assert (
|
assert (
|
||||||
memory["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||||
) # Current value from mock
|
) # Current value from mock
|
||||||
assert (
|
assert (
|
||||||
memory["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
|
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
|
||||||
) # Peak value we set
|
) # Peak value we set
|
||||||
assert (
|
assert (
|
||||||
memory["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||||
) # Current value from mock
|
) # Current value from mock
|
||||||
assert (
|
assert (
|
||||||
memory["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
|
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
|
||||||
) # Peak value we set
|
) # Peak value we set
|
||||||
|
|||||||
Reference in New Issue
Block a user