adding runtime metrics (cpu + gpu memory, steps/s, etc.)
This commit is contained in:
@@ -61,6 +61,8 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
|
from axolotl.telemetry.callbacks import TelemetryCallback
|
||||||
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -176,10 +178,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
SaveAxolotlConfigtoMlflowCallback,
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
)
|
)
|
||||||
|
|
||||||
callbacks.extend(
|
callbacks.append(
|
||||||
[
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
if self.cfg.use_comet and is_comet_available():
|
if self.cfg.use_comet and is_comet_available():
|
||||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||||
@@ -188,6 +188,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
telemetry_manager = TelemetryManager.get_instance()
|
||||||
|
if telemetry_manager.enabled:
|
||||||
|
callbacks.append(TelemetryCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
|||||||
157
src/axolotl/telemetry/callbacks.py
Normal file
157
src/axolotl/telemetry/callbacks.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
|
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TIME_SINCE_LAST = 30
|
||||||
|
|
||||||
|
|
||||||
|
class TelemetryCallback(TrainerCallback):
    """
    Trainer callback for tracking and reporting runtime metrics.

    This callback tracks training progress, runtime, and memory usage, and sends
    telemetry events through the shared `TelemetryManager`:

    - `train-start` when training begins,
    - `train-progress` on the first step of the run and then roughly every
      `report_interval_steps` steps,
    - `train-end` when training finishes, including the tracker's summary.
    """

    # Number of optimizer steps between two `train-progress` reports.
    report_interval_steps: int = 100

    def __init__(self):
        """Initialize the metrics callback."""
        self.tracker = RuntimeMetricsTracker()
        self.telemetry_manager = TelemetryManager.get_instance()
        self.current_epoch = -1
        self.start_time = time.time()
        self.last_report_time = self.start_time
        self.last_report_step = 0
        # Global step number of the first step of this run. Re-based in
        # `on_train_begin` so that a run whose step counter does not start at
        # zero still gets its "first step" report.
        self._first_step = 1

    @staticmethod
    def _latest_logged(state: TrainerState, key: str, default=0):
        """
        Return the most recent value logged under `key`, or `default`.

        The newest `log_history` record is not guaranteed to contain training
        values (a record may only carry other keys), so the history is searched
        backwards instead of looking at the last record only.
        """
        for record in reversed(state.log_history):
            if key in record:
                return record[key]
        return default

    def on_train_begin(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle training start."""
        # Re-base the reporting window on the actual start of training so the
        # first interval's steps/s neither includes setup time spent since the
        # callback was constructed nor counts steps completed before this run.
        self.last_report_step = state.global_step
        self.last_report_time = time.time()
        self._first_step = state.global_step + 1

        self.telemetry_manager.send_event(event_type="train-start")

    def on_train_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle training end."""
        # Keep the event's existing fallbacks: `None` when nothing was logged at
        # all, `0` when records exist but none of them carries the key.
        fallback = 0 if state.log_history else None

        # Send training completion event
        self.telemetry_manager.send_event(
            event_type="train-end",
            properties={
                "loss": self._latest_logged(state, "loss", fallback),
                "learning_rate": self._latest_logged(
                    state, "learning_rate", fallback
                ),
            }
            | self.tracker.metrics.to_dict(),
        )

    def on_epoch_begin(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle epoch start."""
        self.current_epoch += 1
        self.tracker.start_epoch(self.current_epoch)

    def on_epoch_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle epoch end."""
        self.tracker.end_epoch(self.current_epoch)

    def on_step_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        Handle step end.

        Updates the tracker and sends a `train-progress` event when:

        - this is the first step of the run (always reported), or
        - at least `report_interval_steps` steps passed since the last report, or
        - the step number is a multiple of `report_interval_steps` and at least
          `TIME_SINCE_LAST` seconds passed since the last report (rate limit
          for boundary-aligned reports that follow a recent one).
        """
        step = state.global_step
        self.tracker.update_step(step)

        steps_since_last_report = step - self.last_report_step
        is_first_step = step == self._first_step
        interval_elapsed = steps_since_last_report >= self.report_interval_steps
        on_boundary = step % self.report_interval_steps == 0

        if not (is_first_step or interval_elapsed or on_boundary):
            return

        current_time = time.time()
        time_since_last_report = current_time - self.last_report_time

        # Only boundary-aligned reports are rate limited, to avoid flooding
        if not (
            is_first_step
            or interval_elapsed
            or time_since_last_report >= TIME_SINCE_LAST
        ):
            return

        # Calculate steps per second for this interval
        if time_since_last_report > 0 and steps_since_last_report > 0:
            steps_per_second = steps_since_last_report / time_since_last_report
        else:
            steps_per_second = 0

        # Update memory metrics
        self.tracker.update_memory_metrics()

        # Prepare metrics to report
        metrics = {
            "step": step,
            "epoch": self.current_epoch,
            "progress": state.epoch,  # Fractional epoch progress
            "loss": self._latest_logged(state, "loss", 0),
            "learning_rate": self._latest_logged(state, "learning_rate", 0),
            "steps_per_second": steps_per_second,
            "elapsed_time": current_time - self.start_time,
            "time_since_last_report": time_since_last_report,
        }

        # Add memory metrics
        metrics.update(self.tracker.get_memory_metrics())

        # Send telemetry
        self.telemetry_manager.send_event(
            event_type="train-progress", properties=metrics
        )

        # Update last report time and step
        self.last_report_time = current_time
        self.last_report_step = step
|
||||||
@@ -141,7 +141,7 @@ class TelemetryManager:
|
|||||||
return enabled, explicit_enabled
|
return enabled, explicit_enabled
|
||||||
|
|
||||||
def _load_whitelist(self) -> dict:
|
def _load_whitelist(self) -> dict:
|
||||||
"""Load organization/model whitelist"""
|
"""Load HuggingFace Hub organization whitelist"""
|
||||||
with open(self.config.whitelist_path, encoding="utf-8") as f:
|
with open(self.config.whitelist_path, encoding="utf-8") as f:
|
||||||
return yaml.safe_load(f)
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|||||||
167
src/axolotl/telemetry/runtime_metrics.py
Normal file
167
src/axolotl/telemetry/runtime_metrics.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""Telemetry utilities for runtime and memory metrics."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RuntimeMetrics:
|
||||||
|
"""Container for runtime metrics to be tracked throughout training."""
|
||||||
|
|
||||||
|
# Timing metrics
|
||||||
|
start_time: float
|
||||||
|
epoch_start_times: dict[int, float] = field(init=False)
|
||||||
|
epoch_end_times: dict[int, float] = field(init=False)
|
||||||
|
|
||||||
|
# Memory metrics
|
||||||
|
peak_cpu_memory: int = 0
|
||||||
|
peak_gpu_memory: dict[int, int] = field(init=False)
|
||||||
|
|
||||||
|
# Progress metrics
|
||||||
|
total_steps: int = 0
|
||||||
|
current_epoch: int = 0
|
||||||
|
current_step: int = 0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Initialize empty metric mappings."""
|
||||||
|
self.epoch_start_times = {}
|
||||||
|
self.epoch_end_times = {}
|
||||||
|
self.peak_gpu_memory = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def elapsed_time(self) -> float:
|
||||||
|
"""Calculate total elapsed time in seconds."""
|
||||||
|
return time.time() - self.start_time
|
||||||
|
|
||||||
|
def epoch_time(self, epoch: int) -> float | None:
|
||||||
|
"""Calculate time taken for a specific epoch in seconds."""
|
||||||
|
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
|
||||||
|
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def average_epoch_time(self) -> float | None:
|
||||||
|
"""Calculate average time per epoch in seconds."""
|
||||||
|
completed_epochs = [
|
||||||
|
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
|
||||||
|
]
|
||||||
|
if not completed_epochs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
for epoch in completed_epochs:
|
||||||
|
epoch_time = self.epoch_time(epoch)
|
||||||
|
if epoch_time is not None: # Check to avoid mypy warning
|
||||||
|
total_time += epoch_time
|
||||||
|
|
||||||
|
return total_time / len(completed_epochs)
|
||||||
|
|
||||||
|
def steps_per_second(self) -> float | None:
|
||||||
|
"""Calculate average steps per second across all training."""
|
||||||
|
if self.total_steps == 0 or self.elapsed_time == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.total_steps / self.elapsed_time
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert metrics to a dictionary for telemetry reporting."""
|
||||||
|
metrics = {
|
||||||
|
"total_time_seconds": self.elapsed_time,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
"steps_per_second": self.steps_per_second(),
|
||||||
|
"epochs_completed": len(
|
||||||
|
[
|
||||||
|
epoch
|
||||||
|
for epoch in self.epoch_start_times
|
||||||
|
if epoch in self.epoch_end_times
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"peak_cpu_memory_bytes": self.peak_cpu_memory,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add per-epoch timing if available
|
||||||
|
epoch_times: dict[str, float] = {}
|
||||||
|
for epoch in sorted(self.epoch_end_times.keys()):
|
||||||
|
time_taken = self.epoch_time(epoch)
|
||||||
|
if time_taken is not None:
|
||||||
|
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
|
||||||
|
|
||||||
|
if epoch_times:
|
||||||
|
metrics["epoch_times"] = epoch_times # type: ignore
|
||||||
|
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
|
||||||
|
|
||||||
|
# Add GPU memory metrics if available
|
||||||
|
if self.peak_gpu_memory:
|
||||||
|
gpu_metrics: dict[str, int] = {}
|
||||||
|
for gpu_id, memory in self.peak_gpu_memory.items():
|
||||||
|
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
|
||||||
|
metrics["gpu_memory"] = gpu_metrics # type: ignore
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeMetricsTracker:
    """
    Tracker for runtime metrics during training.

    Owns a `RuntimeMetrics` instance (`self.metrics`) and keeps it up to date
    as epochs and steps are reported to it.
    """

    # Number of steps between the memory refreshes triggered by `update_step`.
    memory_update_interval_steps: int = 100

    def __init__(self):
        """Initialize the runtime metrics tracker."""
        self.metrics = RuntimeMetrics(start_time=time.time())
        self.telemetry_manager = TelemetryManager.get_instance()

    def start_epoch(self, epoch: int):
        """Record the start of a new epoch and refresh the peak memory figures."""
        self.metrics.current_epoch = epoch
        self.metrics.epoch_start_times[epoch] = time.time()
        self.update_memory_metrics()

    def end_epoch(self, epoch: int):
        """Record the end of an epoch."""
        self.metrics.epoch_end_times[epoch] = time.time()

    def update_step(self, step: int):
        """
        Update the current step count.

        `total_steps` counts calls to this method, while `current_step` stores
        the step number passed in.
        """
        self.metrics.current_step = step
        self.metrics.total_steps += 1

        # Periodically update memory metrics
        if step % self.memory_update_interval_steps == 0:
            self.update_memory_metrics()

    def update_memory_metrics(self):
        """Update peak memory usage metrics."""
        # CPU memory: resident set size of the current process
        cpu_memory = psutil.Process().memory_info().rss
        self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)

        # GPU memory if available
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                # Use the allocator's own high-water mark: sampling the current
                # allocation only at step/epoch boundaries misses the real peak
                # reached between samples.
                peak_allocated = torch.cuda.max_memory_allocated(i)
                # Keep the running max so an external reset of the allocator's
                # peak statistics cannot lower an already recorded peak.
                self.metrics.peak_gpu_memory[i] = max(
                    self.metrics.peak_gpu_memory.get(i, 0), peak_allocated
                )

    def get_memory_metrics(self) -> dict[str, Any]:
        """
        Get the current memory metrics as a dictionary.

        Returns:
            A dict with a single `"memory"` key holding current and peak CPU
            memory and, when CUDA is available, current and peak memory per GPU.
        """
        memory_metrics = {
            "cpu_memory_bytes": psutil.Process().memory_info().rss,
            "peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
        }

        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                memory_metrics[f"gpu_{i}_memory_bytes"] = torch.cuda.memory_allocated(i)
                memory_metrics[
                    f"gpu_{i}_peak_memory_bytes"
                ] = self.metrics.peak_gpu_memory.get(i, 0)

        return {"memory": memory_metrics}
|
||||||
@@ -89,14 +89,13 @@ def train(
|
|||||||
if model.generation_config is not None:
|
if model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
if TELEMETRY_MANAGER.enabled:
|
TELEMETRY_MANAGER.send_event(
|
||||||
|
event_type="model-load", properties=model.config.to_dict()
|
||||||
|
)
|
||||||
|
if peft_config:
|
||||||
TELEMETRY_MANAGER.send_event(
|
TELEMETRY_MANAGER.send_event(
|
||||||
event_type="model-load", properties=model.config.to_dict()
|
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||||
)
|
)
|
||||||
if peft_config:
|
|
||||||
TELEMETRY_MANAGER.send_event(
|
|
||||||
event_type="peft-config-load", properties=peft_config.to_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
model_ref = None
|
model_ref = None
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
if cfg.rl and cfg.rl != "orpo":
|
||||||
@@ -188,9 +187,6 @@ def train(
|
|||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|
||||||
if TELEMETRY_MANAGER.enabled:
|
|
||||||
TELEMETRY_MANAGER.send_event(event_type="train-start")
|
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||||
@@ -201,10 +197,6 @@ def train(
|
|||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
else:
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
if TELEMETRY_MANAGER.enabled:
|
|
||||||
TELEMETRY_MANAGER.send_event(event_type="train-end")
|
|
||||||
|
|
||||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
# post training
|
# post training
|
||||||
|
|||||||
Reference in New Issue
Block a user