From 4dc018992dccba6fa5e239d0453cbbd565e47e96 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 23 Oct 2025 07:46:55 +0530 Subject: [PATCH] Feat/opentelemetry (#3215) --- examples/llama-3/opentelemetry-qlora.yml | 50 +++ setup.py | 6 + src/axolotl/core/builders/base.py | 12 +- src/axolotl/utils/__init__.py | 7 + src/axolotl/utils/callbacks/opentelemetry.py | 238 +++++++++++++ src/axolotl/utils/schemas/config.py | 2 + src/axolotl/utils/schemas/integrations.py | 24 ++ tests/test_opentelemetry_callback.py | 349 +++++++++++++++++++ 8 files changed, 687 insertions(+), 1 deletion(-) create mode 100644 examples/llama-3/opentelemetry-qlora.yml create mode 100644 src/axolotl/utils/callbacks/opentelemetry.py create mode 100644 tests/test_opentelemetry_callback.py diff --git a/examples/llama-3/opentelemetry-qlora.yml b/examples/llama-3/opentelemetry-qlora.yml new file mode 100644 index 000000000..d8ce7b1ec --- /dev/null +++ b/examples/llama-3/opentelemetry-qlora.yml @@ -0,0 +1,50 @@ +base_model: NousResearch/Llama-3.2-1B +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_4bit: true + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca + +output_dir: ./outputs/opentelemetry-example + +adapter: qlora +sequence_len: 512 +sample_packing: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +# OpenTelemetry Configuration +use_otel_metrics: true +otel_metrics_host: "localhost" +otel_metrics_port: 8000 + +# Disable WandB +use_wandb: false + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: false + +warmup_ratio: 0.1 +evals_per_epoch: 2 +saves_per_epoch: 1 +weight_decay: 0.0 + +special_tokens: + pad_token: "<|end_of_text|>" diff --git a/setup.py b/setup.py index a93d8d49e..9e3de48b5 100644 --- a/setup.py +++ b/setup.py @@ -159,6 +159,12 @@ extras_require = { "llmcompressor==0.5.1", ], "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"], + "opentelemetry": [ + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-prometheus", + "prometheus-client", + ], } install_requires, dependency_links, extras_require_build = parse_requirements( extras_require diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 8c86e335e..2c949f8e7 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr -from axolotl.utils import is_comet_available, is_mlflow_available +from axolotl.utils import ( + is_comet_available, + is_mlflow_available, + is_opentelemetry_available, +) from axolotl.utils.callbacks import ( GCCallback, SaveAxolotlConfigtoWandBCallback, @@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC): callbacks.append( SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_otel_metrics and is_opentelemetry_available(): + from axolotl.utils.callbacks.opentelemetry import ( + OpenTelemetryMetricsCallback, + ) + + callbacks.append(OpenTelemetryMetricsCallback(self.cfg)) if self.cfg.save_first_step: callbacks.append(SaveModelOnFirstStepCallback()) diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 7256a5700..72f8173f3 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -17,6 +17,13 @@ def is_comet_available(): return importlib.util.find_spec("comet_ml") is not None +def is_opentelemetry_available(): + return ( + importlib.util.find_spec("opentelemetry") is not None + and importlib.util.find_spec("prometheus_client") is not None + ) + + def get_pytorch_version() -> tuple[int, int, int]: """ Get Pytorch version as a tuple of (major, minor, patch). diff --git a/src/axolotl/utils/callbacks/opentelemetry.py b/src/axolotl/utils/callbacks/opentelemetry.py new file mode 100644 index 000000000..3f7e56b78 --- /dev/null +++ b/src/axolotl/utils/callbacks/opentelemetry.py @@ -0,0 +1,238 @@ +"""OpenTelemetry metrics callback for Axolotl training""" + +import threading +from typing import Dict, Optional + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +try: + from opentelemetry import metrics + from opentelemetry.exporter.prometheus import PrometheusMetricReader + from opentelemetry.metrics import set_meter_provider + from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider + from prometheus_client import start_http_server + + OPENTELEMETRY_AVAILABLE = True +except ImportError: + LOG.warning("OpenTelemetry not available. pip install [opentelemetry]") + OPENTELEMETRY_AVAILABLE = False + + +class OpenTelemetryMetricsCallback(TrainerCallback): + """ + TrainerCallback that exports training metrics to OpenTelemetry/Prometheus. + + This callback automatically tracks key training metrics including: + - Training loss + - Evaluation loss + - Learning rate + - Epoch progress + - Global step count + - Gradient norm + + Metrics are exposed via HTTP endpoint for Prometheus scraping. + """ + + def __init__(self, cfg): + if not OPENTELEMETRY_AVAILABLE: + LOG.warning("OpenTelemetry not available, metrics will not be collected") + self.metrics_enabled = False + return + + self.cfg = cfg + self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost") + self.metrics_port = getattr(cfg, "otel_metrics_port", 8000) + self.metrics_enabled = True + self.server_started = False + self.metrics_lock = threading.Lock() + + try: + # Create Prometheus metrics reader + prometheus_reader = PrometheusMetricReader() + + # Create meter provider with Prometheus exporter + provider = SDKMeterProvider(metric_readers=[prometheus_reader]) + set_meter_provider(provider) + + # Get meter for creating metrics + self.meter = metrics.get_meter("axolotl.training") + + # Create metrics + self._create_metrics() + + except Exception as e: + LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}") + self.metrics_enabled = False + + def _create_metrics(self): + """Create all metrics that will be tracked""" + self.train_loss_gauge = self.meter.create_gauge( + name="axolotl_train_loss", + description="Current training loss", + unit="1", + ) + + self.eval_loss_gauge = self.meter.create_gauge( + name="axolotl_eval_loss", + description="Current evaluation loss", + unit="1", + ) + + self.learning_rate_gauge = self.meter.create_gauge( + name="axolotl_learning_rate", + description="Current learning rate", + unit="1", + ) + + self.epoch_gauge = self.meter.create_gauge( + name="axolotl_epoch", + description="Current training epoch", + unit="1", + ) + + self.global_step_counter = self.meter.create_counter( + name="axolotl_global_steps", + description="Total training steps completed", + unit="1", + ) + + self.grad_norm_gauge = self.meter.create_gauge( + name="axolotl_gradient_norm", + description="Gradient norm", + unit="1", + ) + + self.memory_usage_gauge = self.meter.create_gauge( + name="axolotl_memory_usage", + description="Current memory usage in MB", + unit="MB", + ) + + def _start_metrics_server(self): + """Start the HTTP server for metrics exposure""" + if self.server_started: + return + + try: + start_http_server(self.metrics_port, addr=self.metrics_host) + self.server_started = True + LOG.info( + f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics" + ) + + except Exception as e: + LOG.error(f"Failed to start OpenTelemetry metrics server: {e}") + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Called at the beginning of training""" + if not self.metrics_enabled: + return + + self._start_metrics_server() + LOG.info("OpenTelemetry metrics collection started") + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: Optional[Dict[str, float]] = None, + **kwargs, + ): + """Called when logging occurs""" + if not self.metrics_enabled or not logs: + return + + if "loss" in logs: + self.train_loss_gauge.set(logs["loss"]) + + if "eval_loss" in logs: + self.eval_loss_gauge.set(logs["eval_loss"]) + + if "learning_rate" in logs: + self.learning_rate_gauge.set(logs["learning_rate"]) + + if "epoch" in logs: + self.epoch_gauge.set(logs["epoch"]) + + if "grad_norm" in logs: + self.grad_norm_gauge.set(logs["grad_norm"]) + if "memory_usage" in logs: + self.memory_usage_gauge.set(logs["memory_usage"]) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Called at the end of each training step""" + if not self.metrics_enabled: + return + + # Update step counter and epoch + self.global_step_counter.add(1) + if state.epoch is not None: + self.epoch_gauge.set(state.epoch) + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + metrics: Optional[Dict[str, float]] = None, + **kwargs, + ): + """Called after evaluation""" + if not self.metrics_enabled or not metrics: + return + + if "eval_loss" in metrics: + self.eval_loss_gauge.set(metrics["eval_loss"]) + + # Record any other eval metrics as gauges + for key, value in metrics.items(): + if key.startswith("eval_") and isinstance(value, (int, float)): + # Create gauge for this metric if it doesn't exist + gauge_name = f"axolotl_{key}" + try: + gauge = self.meter.create_gauge( + name=gauge_name, + description=f"Evaluation metric: {key}", + unit="1", + ) + gauge.set(value) + except Exception as e: + LOG.warning(f"Failed to create/update metric {gauge_name}: {e}") + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Called at the end of training""" + if not self.metrics_enabled: + return + + LOG.info("Training completed. OpenTelemetry metrics collection finished.") + LOG.info( + f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics" + ) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 4d1d0aab2..86b3aa17b 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -30,6 +30,7 @@ from axolotl.utils.schemas.integrations import ( GradioConfig, LISAConfig, MLFlowConfig, + OpenTelemetryConfig, RayConfig, WandbConfig, ) @@ -60,6 +61,7 @@ class AxolotlInputConfig( WandbConfig, MLFlowConfig, CometConfig, + OpenTelemetryConfig, LISAConfig, GradioConfig, RayConfig, diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index 7332c7d39..97d675569 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -176,3 +176,27 @@ class RayConfig(BaseModel): "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." }, ) + + +class OpenTelemetryConfig(BaseModel): + """OpenTelemetry configuration subset""" + + use_otel_metrics: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Enable OpenTelemetry metrics collection and Prometheus export" + }, + ) + otel_metrics_host: str | None = Field( + default="localhost", + json_schema_extra={ + "title": "OpenTelemetry Metrics Host", + "description": "Host to bind the OpenTelemetry metrics server to", + }, + ) + otel_metrics_port: int | None = Field( + default=8000, + json_schema_extra={ + "description": "Port for the Prometheus metrics HTTP server" + }, + ) diff --git a/tests/test_opentelemetry_callback.py b/tests/test_opentelemetry_callback.py new file mode 100644 index 000000000..294ff6585 --- /dev/null +++ b/tests/test_opentelemetry_callback.py @@ -0,0 +1,349 @@ +"""Tests for OpenTelemetry metrics callback functionality.""" + +import time + +import pytest + +from axolotl.utils.dict import DictDefault + + +@pytest.fixture +def mock_otel_config(): + """Mock configuration for OpenTelemetry callback.""" + return DictDefault( + { + "use_otel_metrics": True, + "otel_metrics_host": "localhost", + "otel_metrics_port": 8003, # Use unique port for tests + } + ) + + +@pytest.fixture +def mock_trainer_state(): + """Mock trainer state for callback testing.""" + from transformers import TrainerState + + state = TrainerState() + state.epoch = 1.0 + state.global_step = 100 + return state + + +@pytest.fixture +def mock_training_args(): + """Mock training arguments for callback testing.""" + from transformers import TrainingArguments + + return TrainingArguments(output_dir="/tmp/test") + + +@pytest.fixture +def mock_trainer_control(): + """Mock trainer control for callback testing.""" + from transformers.trainer_callback import TrainerControl + + return TrainerControl() + + +class TestOpenTelemetryConfig: + """Test OpenTelemetry configuration schema.""" + + def test_config_schema_valid(self): + """Test OpenTelemetry configuration schema validation.""" + from axolotl.utils.schemas.integrations import OpenTelemetryConfig + + # Test valid config + valid_config = { + "use_otel_metrics": True, + "otel_metrics_host": "localhost", + "otel_metrics_port": 8000, + } + + otel_config = OpenTelemetryConfig(**valid_config) + assert otel_config.use_otel_metrics is True + assert otel_config.otel_metrics_host == "localhost" + assert otel_config.otel_metrics_port == 8000 + + def test_config_defaults(self): + """Test OpenTelemetry configuration default values.""" + from axolotl.utils.schemas.integrations import OpenTelemetryConfig + + # Test minimal config with defaults + minimal_config = {"use_otel_metrics": True} + + otel_config = OpenTelemetryConfig(**minimal_config) + assert otel_config.use_otel_metrics is True + assert otel_config.otel_metrics_host == "localhost" # default + assert otel_config.otel_metrics_port == 8000 # default + + def test_config_disabled_by_default(self): + """Test that OpenTelemetry is disabled by default.""" + from axolotl.utils.schemas.integrations import OpenTelemetryConfig + + # Test default config + default_config = OpenTelemetryConfig() + assert default_config.use_otel_metrics is False + + +class TestOpenTelemetryCallback: + """Test OpenTelemetry callback functionality.""" + + def test_callback_import(self): + """Test that OpenTelemetry callback can be imported.""" + from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback + + assert OpenTelemetryMetricsCallback is not None + + def test_callback_graceful_fallback(self, mock_otel_config): + """Test callback gracefully handles missing dependencies.""" + from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback + + # This should not raise an exception even if dependencies are missing + callback = OpenTelemetryMetricsCallback(mock_otel_config) + + # Callback should exist but may have metrics disabled + assert callback is not None + assert hasattr(callback, "metrics_enabled") + + def test_callback_initialization_enabled(self, mock_otel_config): + """Test callback initialization when OpenTelemetry is available.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + + if OPENTELEMETRY_AVAILABLE: + assert callback.metrics_enabled is True + assert callback.cfg == mock_otel_config + assert callback.metrics_host == "localhost" + assert callback.metrics_port == 8003 + else: + assert callback.metrics_enabled is False + + def test_metrics_server_lifecycle( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test metrics server starts and stops correctly.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + + # Start server + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + assert callback.server_started is True + + # End training + callback.on_train_end( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + def test_metrics_recording( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test that metrics are recorded during training.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Test logging metrics + test_logs = { + "loss": 0.5, + "learning_rate": 1e-4, + "grad_norm": 0.8, + } + + # This should not raise an exception + callback.on_log( + mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs + ) + assert callback.metrics_enabled is True + + def test_evaluation_metrics( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test evaluation metrics recording.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Test evaluation metrics + eval_logs = { + "eval_loss": 0.3, + "eval_accuracy": 0.95, + } + + # This should not raise an exception + callback.on_evaluate( + mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs + ) + assert callback.metrics_enabled is True + + def test_thread_safety(self, mock_otel_config): + """Test that callback has thread safety mechanisms.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + assert hasattr(callback, "metrics_lock") + # Check it's a lock-like object + assert hasattr(callback.metrics_lock, "__enter__") + assert hasattr(callback.metrics_lock, "__exit__") + + +class TestOpenTelemetryIntegration: + """Integration tests for OpenTelemetry.""" + + def test_availability_check(self): + """Test availability check function.""" + from axolotl.utils import is_opentelemetry_available + + result = is_opentelemetry_available() + assert isinstance(result, bool) + + def test_prometheus_endpoint_basic( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test basic Prometheus endpoint functionality.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + try: + import requests + except ImportError: + pytest.skip("requests library not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + if not callback.server_started: + pytest.skip("Metrics server failed to start") + + # Give server time to start + time.sleep(1) + + # Try to access metrics endpoint + try: + response = requests.get( + f"http://{callback.metrics_host}:{callback.metrics_port}/metrics", + timeout=2, + ) + assert response.status_code == 200 + # Check for Prometheus format + assert "# TYPE" in response.text or "# HELP" in response.text + except requests.exceptions.RequestException: + pytest.skip( + "Could not connect to metrics endpoint - this is expected in some environments" + ) + + +class TestOpenTelemetryCallbackMethods: + """Test specific callback methods.""" + + def test_step_end_callback( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test step end callback method.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Should not raise an exception + callback.on_step_end( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + def test_epoch_end_callback( + self, + mock_otel_config, + mock_trainer_state, + mock_training_args, + mock_trainer_control, + ): + """Test epoch end callback method.""" + from axolotl.utils.callbacks.opentelemetry import ( + OPENTELEMETRY_AVAILABLE, + OpenTelemetryMetricsCallback, + ) + + if not OPENTELEMETRY_AVAILABLE: + pytest.skip("OpenTelemetry dependencies not available") + + callback = OpenTelemetryMetricsCallback(mock_otel_config) + callback.on_train_begin( + mock_training_args, mock_trainer_state, mock_trainer_control + ) + + # Should not raise an exception + callback.on_epoch_end( + mock_training_args, mock_trainer_state, mock_trainer_control + )