From 3760175440f341b3d8ae0bf773879e2054397ecb Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 21 Feb 2025 19:01:35 +0000 Subject: [PATCH] adding runtime metrics (cpu + gpu memory, steps/s, etc.) --- src/axolotl/core/trainer_builder.py | 12 +- src/axolotl/telemetry/callbacks.py | 157 +++++++++++++++++++++ src/axolotl/telemetry/manager.py | 2 +- src/axolotl/telemetry/runtime_metrics.py | 167 +++++++++++++++++++++++ src/axolotl/train.py | 18 +-- 5 files changed, 338 insertions(+), 18 deletions(-) create mode 100644 src/axolotl/telemetry/callbacks.py create mode 100644 src/axolotl/telemetry/runtime_metrics.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 12346b8a2..311bfe667 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -61,6 +61,8 @@ from axolotl.core.training_args import ( from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback +from axolotl.telemetry.callbacks import TelemetryCallback +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, @@ -176,10 +178,8 @@ class TrainerBuilderBase(abc.ABC): SaveAxolotlConfigtoMlflowCallback, ) - callbacks.extend( - [ - SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path), - ] + callbacks.append( + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) ) if self.cfg.use_comet and is_comet_available(): from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback @@ -188,6 +188,10 @@ class TrainerBuilderBase(abc.ABC): SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + telemetry_manager = TelemetryManager.get_instance() + if telemetry_manager.enabled: + callbacks.append(TelemetryCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): diff --git a/src/axolotl/telemetry/callbacks.py b/src/axolotl/telemetry/callbacks.py new file mode 100644 index 000000000..cbafe2086 --- /dev/null +++ b/src/axolotl/telemetry/callbacks.py @@ -0,0 +1,157 @@ +"""Trainer callbacks for reporting runtime metrics at regular intervals.""" + +import logging +import time + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.telemetry.manager import TelemetryManager +from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker + +LOG = logging.getLogger(__name__) + +TIME_SINCE_LAST = 30 + + +class TelemetryCallback(TrainerCallback): + """ + Trainer callback for tracking and reporting runtime metrics. + + This callback tracks training progress, runtime, and memory usage, + sending telemetry at configurable intervals. 
+ """ + + report_interval_steps: int = 100 + + def __init__(self): + """Initialize the metrics callback.""" + self.tracker = RuntimeMetricsTracker() + self.telemetry_manager = TelemetryManager.get_instance() + self.current_epoch = -1 + self.start_time = time.time() + self.last_report_time = self.start_time + self.last_report_step = 0 + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + """Handle training start.""" + self.telemetry_manager.send_event(event_type="train-start") + + def on_train_end( + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + """Handle training end.""" + # Send training completion event + self.telemetry_manager.send_event( + event_type="train-end", + properties={ + "loss": state.log_history[-1].get("loss", 0) + if state.log_history + else None, + "learning_rate": state.log_history[-1].get("learning_rate", 0) + if state.log_history + else None, + } + | self.tracker.metrics.to_dict(), + ) + + def on_epoch_begin( + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + """Handle epoch start.""" + self.current_epoch += 1 + self.tracker.start_epoch(self.current_epoch) + + def on_epoch_end( + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + """Handle epoch end.""" + self.tracker.end_epoch(self.current_epoch) + + def on_step_end( + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + """Handle step end.""" + step = state.global_step + self.tracker.update_step(step) + + # Check if we should report metrics + should_report = ( + step % self.report_interval_steps == 0 + or step == 1 # Always report first step + or step - self.last_report_step >= self.report_interval_steps + ) + + if should_report: + current_time = time.time() + time_since_last_report = current_time - self.last_report_time + steps_since_last_report = step - self.last_report_step + + # Only report if enough time has passed to avoid flooding + if ( + time_since_last_report >= TIME_SINCE_LAST + or steps_since_last_report >= self.report_interval_steps + ): + # Calculate steps per second for this interval + if time_since_last_report > 0 and steps_since_last_report > 0: + steps_per_second = steps_since_last_report / time_since_last_report + else: + steps_per_second = 0 + + # Update memory metrics + self.tracker.update_memory_metrics() + + # Prepare metrics to report + metrics = { + "step": step, + "epoch": self.current_epoch, + "progress": state.epoch, # Fractional epoch progress + "loss": state.log_history[-1].get("loss", 0) + if state.log_history + else 0, + "learning_rate": state.log_history[-1].get("learning_rate", 0) + if state.log_history + else 0, + "steps_per_second": steps_per_second, + "elapsed_time": current_time - self.start_time, + "time_since_last_report": 
time_since_last_report, + } + + # Add memory metrics + memory_metrics = self.tracker.get_memory_metrics() + metrics.update(memory_metrics) + + # Send telemetry + self.telemetry_manager.send_event( + event_type="train-progress", properties=metrics + ) + + # Update last report time and step + self.last_report_time = current_time + self.last_report_step = step diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py index 0b5e2933e..e4a6d0bea 100644 --- a/src/axolotl/telemetry/manager.py +++ b/src/axolotl/telemetry/manager.py @@ -141,7 +141,7 @@ class TelemetryManager: return enabled, explicit_enabled def _load_whitelist(self) -> dict: - """Load organization/model whitelist""" + """Load HuggingFace Hub organization whitelist""" with open(self.config.whitelist_path, encoding="utf-8") as f: return yaml.safe_load(f) diff --git a/src/axolotl/telemetry/runtime_metrics.py b/src/axolotl/telemetry/runtime_metrics.py new file mode 100644 index 000000000..d0f52b88b --- /dev/null +++ b/src/axolotl/telemetry/runtime_metrics.py @@ -0,0 +1,167 @@ +"""Telemetry utilities for runtime and memory metrics.""" + +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +import psutil +import torch + +from axolotl.telemetry.manager import TelemetryManager + +LOG = logging.getLogger(__name__) + + +@dataclass +class RuntimeMetrics: + """Container for runtime metrics to be tracked throughout training.""" + + # Timing metrics + start_time: float + epoch_start_times: dict[int, float] = field(init=False) + epoch_end_times: dict[int, float] = field(init=False) + + # Memory metrics + peak_cpu_memory: int = 0 + peak_gpu_memory: dict[int, int] = field(init=False) + + # Progress metrics + total_steps: int = 0 + current_epoch: int = 0 + current_step: int = 0 + + def __post_init__(self): + """Initialize empty metric mappings.""" + self.epoch_start_times = {} + self.epoch_end_times = {} + self.peak_gpu_memory = {} + + @property + def elapsed_time(self) -> float: + """Calculate total elapsed time in seconds.""" + return time.time() - self.start_time + + def epoch_time(self, epoch: int) -> float | None: + """Calculate time taken for a specific epoch in seconds.""" + if epoch in self.epoch_start_times and epoch in self.epoch_end_times: + return self.epoch_end_times[epoch] - self.epoch_start_times[epoch] + + return None + + def average_epoch_time(self) -> float | None: + """Calculate average time per epoch in seconds.""" + completed_epochs = [ + epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times + ] + if not completed_epochs: + return None + + total_time = 0.0 + for epoch in completed_epochs: + epoch_time = self.epoch_time(epoch) + if epoch_time is not None: # Check to avoid mypy warning + total_time += epoch_time + + return total_time / len(completed_epochs) + + def steps_per_second(self) -> float | None: + """Calculate average steps per second across all training.""" + if self.total_steps == 0 or self.elapsed_time == 0: + return None + + return self.total_steps / self.elapsed_time + + def to_dict(self) -> dict[str, Any]: + """Convert metrics to a dictionary for telemetry reporting.""" + metrics = { + "total_time_seconds": self.elapsed_time, + "total_steps": self.total_steps, + "steps_per_second": self.steps_per_second(), + "epochs_completed": len( + [ + epoch + for epoch in self.epoch_start_times + if epoch in self.epoch_end_times + ] + ), + "peak_cpu_memory_bytes": self.peak_cpu_memory, + } + + # Add per-epoch timing if available + 
epoch_times: dict[str, float] = {} + for epoch in sorted(self.epoch_end_times.keys()): + time_taken = self.epoch_time(epoch) + if time_taken is not None: + epoch_times[f"epoch_{epoch}_seconds"] = time_taken + + if epoch_times: + metrics["epoch_times"] = epoch_times # type: ignore + metrics["average_epoch_time_seconds"] = self.average_epoch_time() + + # Add GPU memory metrics if available + if self.peak_gpu_memory: + gpu_metrics: dict[str, int] = {} + for gpu_id, memory in self.peak_gpu_memory.items(): + gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory + metrics["gpu_memory"] = gpu_metrics # type: ignore + + return metrics + + +class RuntimeMetricsTracker: + """Tracker for runtime metrics during training.""" + + def __init__(self): + """Initialize the runtime metrics tracker.""" + self.metrics = RuntimeMetrics(start_time=time.time()) + self.telemetry_manager = TelemetryManager.get_instance() + + def start_epoch(self, epoch: int): + """Record the start of a new epoch.""" + self.metrics.current_epoch = epoch + self.metrics.epoch_start_times[epoch] = time.time() + self.update_memory_metrics() + + def end_epoch(self, epoch: int): + """Record the end of an epoch.""" + self.metrics.epoch_end_times[epoch] = time.time() + + def update_step(self, step: int): + """Update the current step count.""" + self.metrics.current_step = step + self.metrics.total_steps += 1 + + # Periodically update memory metrics (e.g., every 100 steps) + if step % 100 == 0: + self.update_memory_metrics() + + def update_memory_metrics(self): + """Update peak memory usage metrics.""" + # CPU memory + cpu_memory = psutil.Process().memory_info().rss + self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory) + + # GPU memory if available + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + memory_used = torch.cuda.memory_allocated(i) + self.metrics.peak_gpu_memory[i] = max( + self.metrics.peak_gpu_memory.get(i, 0), memory_used + ) + + def get_memory_metrics(self) -> dict[str, Any]: + """Get the current memory metrics as a dictionary.""" + memory_metrics = { + "cpu_memory_bytes": psutil.Process().memory_info().rss, + "peak_cpu_memory_bytes": self.metrics.peak_cpu_memory, + } + + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + memory_metrics[f"gpu_{i}_memory_bytes"] = torch.cuda.memory_allocated(i) + memory_metrics[ + f"gpu_{i}_peak_memory_bytes" + ] = self.metrics.peak_gpu_memory.get(i, 0) + + return {"memory": memory_metrics} diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 2aa1fa9fe..943b58f22 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -89,14 +89,13 @@ def train( if model.generation_config is not None: model.generation_config.do_sample = True - if TELEMETRY_MANAGER.enabled: + TELEMETRY_MANAGER.send_event( + event_type="model-load", properties=model.config.to_dict() + ) + if peft_config: TELEMETRY_MANAGER.send_event( - event_type="model-load", properties=model.config.to_dict() + event_type="peft-config-load", properties=peft_config.to_dict() ) - if peft_config: - TELEMETRY_MANAGER.send_event( - event_type="peft-config-load", properties=peft_config.to_dict() - ) model_ref = None if cfg.rl and cfg.rl != "orpo": @@ -188,9 +187,6 @@ def train( if cfg.group_by_length: LOG.info("hang tight... 
sorting dataset for group_by_length") - if TELEMETRY_MANAGER.enabled: - TELEMETRY_MANAGER.send_event(event_type="train-start") - if cfg.flash_optimum: with torch.backends.cuda.sdp_kernel( # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... @@ -201,10 +197,6 @@ def train( trainer.train(resume_from_checkpoint=resume_from_checkpoint) else: trainer.train(resume_from_checkpoint=resume_from_checkpoint) - - if TELEMETRY_MANAGER.enabled: - TELEMETRY_MANAGER.send_event(event_type="train-end") - LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") # post training
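
Note on the reporting gate in TelemetryCallback.on_step_end: a "train-progress"
event is sent only when a step-based trigger fires (every report_interval_steps
steps, on step 1, or once at least report_interval_steps steps have accumulated
since the last report) and a flood guard then passes (at least TIME_SINCE_LAST
seconds or report_interval_steps steps since the last report). The sketch below
is a dependency-free mirror of that logic for sanity-checking outside a Trainer;
the constants match the patch, while the should_report helper and the simulated
clock are illustrative only and not part of the change.

# Standalone sketch of the train-progress gate (not part of the patch).
REPORT_INTERVAL_STEPS = 100  # mirrors TelemetryCallback.report_interval_steps
TIME_SINCE_LAST = 30  # mirrors the module-level constant in callbacks.py


def should_report(step: int, now: float, last_step: int, last_time: float) -> bool:
    """Two-stage gate: a step-based trigger, then a time/step flood guard."""
    triggered = (
        step % REPORT_INTERVAL_STEPS == 0
        or step == 1  # triggers, but may still be flood-guarded below
        or step - last_step >= REPORT_INTERVAL_STEPS
    )
    if not triggered:
        return False
    return (
        now - last_time >= TIME_SINCE_LAST
        or step - last_step >= REPORT_INTERVAL_STEPS
    )


# Simulate 250 steps at 0.5 s/step: reports land at steps 100 and 200.
last_step, last_time = 0, 0.0
for step in range(1, 251):
    now = step * 0.5  # simulated wall clock, seconds since training start
    if should_report(step, now, last_step, last_time):
        rate = (step - last_step) / (now - last_time)
        print(f"report at step {step} ({rate:.1f} steps/s)")
        last_step, last_time = step, now

Two consequences follow directly from the gate. First, with the shipped
constants the step == 1 trigger is filtered out by the flood guard (one step,
and normally well under 30 seconds since training began), so the first
train-progress event usually goes out at step report_interval_steps. Second,
the memory figures in each report arrive nested under a "memory" key, because
on_step_end merges the dict returned by
RuntimeMetricsTracker.get_memory_metrics() into the event properties.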