adding runtime metrics (cpu + gpu memory, steps/s, etc.)

Dan Saunders
2025-02-21 19:01:35 +00:00
parent 7927abff90
commit 3760175440
5 changed files with 338 additions and 18 deletions

View File

@@ -61,6 +61,8 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
+from axolotl.telemetry.callbacks import TelemetryCallback
+from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
    EvalFirstStepCallback,
@@ -176,10 +178,8 @@ class TrainerBuilderBase(abc.ABC):
                SaveAxolotlConfigtoMlflowCallback,
            )
-            callbacks.extend(
-                [
-                    SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
-                ]
+            callbacks.append(
+                SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
            )
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -188,6 +188,10 @@ class TrainerBuilderBase(abc.ABC):
                SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
            )

+        telemetry_manager = TelemetryManager.get_instance()
+        if telemetry_manager.enabled:
+            callbacks.append(TelemetryCallback())
+
        return callbacks

    def get_post_trainer_create_callbacks(self, trainer):
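For reference, the builder only touches three pieces of the telemetry API: TelemetryManager.get_instance(), its enabled flag, and send_event(event_type=..., properties=...). The sketch below illustrates that surface only; the class body is a hypothetical stand-in, not the actual implementation in axolotl.telemetry.manager.

# Hypothetical stand-in for the TelemetryManager surface used above; only the
# call signatures (get_instance, enabled, send_event) are taken from this diff.
from typing import Any


class TelemetryManagerSketch:
    _instance: "TelemetryManagerSketch | None" = None

    def __init__(self, enabled: bool = False):
        self.enabled = enabled

    @classmethod
    def get_instance(cls) -> "TelemetryManagerSketch":
        # Singleton accessor mirroring TelemetryManager.get_instance()
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
        # The real manager forwards events to a telemetry backend; this stub just prints.
        if self.enabled:
            print(f"telemetry event: {event_type} {properties or {}}")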

View File

@@ -0,0 +1,157 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
import logging
import time
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
LOG = logging.getLogger(__name__)
TIME_SINCE_LAST = 30
class TelemetryCallback(TrainerCallback):
"""
Trainer callback for tracking and reporting runtime metrics.
This callback tracks training progress, runtime, and memory usage,
sending telemetry at configurable intervals.
"""
report_interval_steps: int = 100
def __init__(self):
"""Initialize the metrics callback."""
self.tracker = RuntimeMetricsTracker()
self.telemetry_manager = TelemetryManager.get_instance()
self.current_epoch = -1
self.start_time = time.time()
self.last_report_time = self.start_time
self.last_report_step = 0
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle training start."""
self.telemetry_manager.send_event(event_type="train-start")
def on_train_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle training end."""
# Send training completion event
self.telemetry_manager.send_event(
event_type="train-end",
properties={
"loss": state.log_history[-1].get("loss", 0)
if state.log_history
else None,
"learning_rate": state.log_history[-1].get("learning_rate", 0)
if state.log_history
else None,
}
| self.tracker.metrics.to_dict(),
)
def on_epoch_begin(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle epoch start."""
self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch)
def on_epoch_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle epoch end."""
self.tracker.end_epoch(self.current_epoch)
def on_step_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle step end."""
step = state.global_step
self.tracker.update_step(step)
# Check if we should report metrics
should_report = (
step % self.report_interval_steps == 0
or step == 1 # Always report first step
or step - self.last_report_step >= self.report_interval_steps
)
if should_report:
current_time = time.time()
time_since_last_report = current_time - self.last_report_time
steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed to avoid flooding
if (
time_since_last_report >= TIME_SINCE_LAST
or steps_since_last_report >= self.report_interval_steps
):
# Calculate steps per second for this interval
if time_since_last_report > 0 and steps_since_last_report > 0:
steps_per_second = steps_since_last_report / time_since_last_report
else:
steps_per_second = 0
# Update memory metrics
self.tracker.update_memory_metrics()
# Prepare metrics to report
metrics = {
"step": step,
"epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress
"loss": state.log_history[-1].get("loss", 0)
if state.log_history
else 0,
"learning_rate": state.log_history[-1].get("learning_rate", 0)
if state.log_history
else 0,
"steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report,
}
# Add memory metrics
memory_metrics = self.tracker.get_memory_metrics()
metrics.update(memory_metrics)
# Send telemetry
self.telemetry_manager.send_event(
event_type="train-progress", properties=metrics
)
# Update last report time and step
self.last_report_time = current_time
self.last_report_step = step
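A rough way to see the reporting cadence without a real training run is to drive the callback by hand with transformers' TrainerState and TrainerControl. This is only a sketch: it assumes TelemetryManager can initialize in the current environment and that telemetry is enabled; otherwise the send_event calls are effectively no-ops.

# Sketch: exercising TelemetryCallback outside a real Trainer run.
from transformers import TrainerControl, TrainerState, TrainingArguments

from axolotl.telemetry.callbacks import TelemetryCallback

callback = TelemetryCallback()
args = TrainingArguments(output_dir="./outputs")
state, control = TrainerState(), TrainerControl()

callback.on_train_begin(args, state, control)   # emits "train-start"
callback.on_epoch_begin(args, state, control)
for step in range(1, 301):
    state.global_step = step
    # "train-progress" is sent at most once per report interval
    # (see should_report and TIME_SINCE_LAST above)
    callback.on_step_end(args, state, control)
callback.on_epoch_end(args, state, control)
callback.on_train_end(args, state, control)     # emits "train-end" with summary metrics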

View File

@@ -141,7 +141,7 @@ class TelemetryManager:
        return enabled, explicit_enabled

    def _load_whitelist(self) -> dict:
-        """Load organization/model whitelist"""
+        """Load HuggingFace Hub organization whitelist"""
        with open(self.config.whitelist_path, encoding="utf-8") as f:
            return yaml.safe_load(f)
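The whitelist file itself is not part of this diff, so its schema is not shown here. Purely as an illustration of the kind of mapping yaml.safe_load would return, one hypothetical shape might be:

# Hypothetical whitelist contents; the real file at config.whitelist_path may
# use a different schema.
import yaml

example = yaml.safe_load(
    """
    organizations:
      - meta-llama
      - mistralai
    """
)
print(example)  # {'organizations': ['meta-llama', 'mistralai']}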

View File

@@ -0,0 +1,167 @@
"""Telemetry utilities for runtime and memory metrics."""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import psutil
import torch
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
@dataclass
class RuntimeMetrics:
"""Container for runtime metrics to be tracked throughout training."""
# Timing metrics
start_time: float
epoch_start_times: dict[int, float] = field(init=False)
epoch_end_times: dict[int, float] = field(init=False)
# Memory metrics
peak_cpu_memory: int = 0
peak_gpu_memory: dict[int, int] = field(init=False)
# Progress metrics
total_steps: int = 0
current_epoch: int = 0
current_step: int = 0
def __post_init__(self):
"""Initialize empty metric mappings."""
self.epoch_start_times = {}
self.epoch_end_times = {}
self.peak_gpu_memory = {}
@property
def elapsed_time(self) -> float:
"""Calculate total elapsed time in seconds."""
return time.time() - self.start_time
def epoch_time(self, epoch: int) -> float | None:
"""Calculate time taken for a specific epoch in seconds."""
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
return None
def average_epoch_time(self) -> float | None:
"""Calculate average time per epoch in seconds."""
completed_epochs = [
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
]
if not completed_epochs:
return None
total_time = 0.0
for epoch in completed_epochs:
epoch_time = self.epoch_time(epoch)
if epoch_time is not None: # Check to avoid mypy warning
total_time += epoch_time
return total_time / len(completed_epochs)
def steps_per_second(self) -> float | None:
"""Calculate average steps per second across all training."""
if self.total_steps == 0 or self.elapsed_time == 0:
return None
return self.total_steps / self.elapsed_time
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to a dictionary for telemetry reporting."""
metrics = {
"total_time_seconds": self.elapsed_time,
"total_steps": self.total_steps,
"steps_per_second": self.steps_per_second(),
"epochs_completed": len(
[
epoch
for epoch in self.epoch_start_times
if epoch in self.epoch_end_times
]
),
"peak_cpu_memory_bytes": self.peak_cpu_memory,
}
# Add per-epoch timing if available
epoch_times: dict[str, float] = {}
for epoch in sorted(self.epoch_end_times.keys()):
time_taken = self.epoch_time(epoch)
if time_taken is not None:
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
if epoch_times:
metrics["epoch_times"] = epoch_times # type: ignore
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
# Add GPU memory metrics if available
if self.peak_gpu_memory:
gpu_metrics: dict[str, int] = {}
for gpu_id, memory in self.peak_gpu_memory.items():
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
metrics["gpu_memory"] = gpu_metrics # type: ignore
return metrics
class RuntimeMetricsTracker:
"""Tracker for runtime metrics during training."""
def __init__(self):
"""Initialize the runtime metrics tracker."""
self.metrics = RuntimeMetrics(start_time=time.time())
self.telemetry_manager = TelemetryManager.get_instance()
def start_epoch(self, epoch: int):
"""Record the start of a new epoch."""
self.metrics.current_epoch = epoch
self.metrics.epoch_start_times[epoch] = time.time()
self.update_memory_metrics()
def end_epoch(self, epoch: int):
"""Record the end of an epoch."""
self.metrics.epoch_end_times[epoch] = time.time()
def update_step(self, step: int):
"""Update the current step count."""
self.metrics.current_step = step
self.metrics.total_steps += 1
# Periodically update memory metrics (e.g., every 100 steps)
if step % 100 == 0:
self.update_memory_metrics()
def update_memory_metrics(self):
"""Update peak memory usage metrics."""
# CPU memory
cpu_memory = psutil.Process().memory_info().rss
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
# GPU memory if available
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_used = torch.cuda.memory_allocated(i)
self.metrics.peak_gpu_memory[i] = max(
self.metrics.peak_gpu_memory.get(i, 0), memory_used
)
def get_memory_metrics(self) -> dict[str, Any]:
"""Get the current memory metrics as a dictionary."""
memory_metrics = {
"cpu_memory_bytes": psutil.Process().memory_info().rss,
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
}
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_metrics[f"gpu_{i}_memory_bytes"] = torch.cuda.memory_allocated(i)
memory_metrics[
f"gpu_{i}_peak_memory_bytes"
] = self.metrics.peak_gpu_memory.get(i, 0)
return {"memory": memory_metrics}

View File

@@ -89,14 +89,13 @@ def train(
    if model.generation_config is not None:
        model.generation_config.do_sample = True

-    if TELEMETRY_MANAGER.enabled:
-        TELEMETRY_MANAGER.send_event(
-            event_type="model-load", properties=model.config.to_dict()
-        )
-        if peft_config:
-            TELEMETRY_MANAGER.send_event(
-                event_type="peft-config-load", properties=peft_config.to_dict()
-            )
+    TELEMETRY_MANAGER.send_event(
+        event_type="model-load", properties=model.config.to_dict()
+    )
+    if peft_config:
+        TELEMETRY_MANAGER.send_event(
+            event_type="peft-config-load", properties=peft_config.to_dict()
+        )

    model_ref = None
    if cfg.rl and cfg.rl != "orpo":
@@ -188,9 +187,6 @@ def train(
    if cfg.group_by_length:
        LOG.info("hang tight... sorting dataset for group_by_length")

-    if TELEMETRY_MANAGER.enabled:
-        TELEMETRY_MANAGER.send_event(event_type="train-start")
-
    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -201,10 +197,6 @@ def train(
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-    if TELEMETRY_MANAGER.enabled:
-        TELEMETRY_MANAGER.send_event(event_type="train-end")
-
    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

    # post training