Feat/opentelemetry (#3215)
This commit is contained in:
50
examples/llama-3/opentelemetry-qlora.yml
Normal file
50
examples/llama-3/opentelemetry-qlora.yml
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
|
||||||
|
output_dir: ./outputs/opentelemetry-example
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
# OpenTelemetry Configuration
|
||||||
|
use_otel_metrics: true
|
||||||
|
otel_metrics_host: "localhost"
|
||||||
|
otel_metrics_port: 8000
|
||||||
|
|
||||||
|
# Disable WandB
|
||||||
|
use_wandb: false
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_32bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
6
setup.py
6
setup.py
@@ -159,6 +159,12 @@ extras_require = {
|
|||||||
"llmcompressor==0.5.1",
|
"llmcompressor==0.5.1",
|
||||||
],
|
],
|
||||||
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
||||||
|
"opentelemetry": [
|
||||||
|
"opentelemetry-api",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"opentelemetry-exporter-prometheus",
|
||||||
|
"prometheus-client",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
extras_require
|
extras_require
|
||||||
|
|||||||
@@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig
|
|||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import (
|
||||||
|
is_comet_available,
|
||||||
|
is_mlflow_available,
|
||||||
|
is_opentelemetry_available,
|
||||||
|
)
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GCCallback,
|
GCCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
@@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
|
||||||
if self.cfg.save_first_step:
|
if self.cfg.save_first_step:
|
||||||
callbacks.append(SaveModelOnFirstStepCallback())
|
callbacks.append(SaveModelOnFirstStepCallback())
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,13 @@ def is_comet_available():
|
|||||||
return importlib.util.find_spec("comet_ml") is not None
|
return importlib.util.find_spec("comet_ml") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_opentelemetry_available():
|
||||||
|
return (
|
||||||
|
importlib.util.find_spec("opentelemetry") is not None
|
||||||
|
and importlib.util.find_spec("prometheus_client") is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_pytorch_version() -> tuple[int, int, int]:
|
def get_pytorch_version() -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Get Pytorch version as a tuple of (major, minor, patch).
|
Get Pytorch version as a tuple of (major, minor, patch).
|
||||||
|
|||||||
238
src/axolotl/utils/callbacks/opentelemetry.py
Normal file
238
src/axolotl/utils/callbacks/opentelemetry.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""OpenTelemetry metrics callback for Axolotl training"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from opentelemetry import metrics
|
||||||
|
from opentelemetry.exporter.prometheus import PrometheusMetricReader
|
||||||
|
from opentelemetry.metrics import set_meter_provider
|
||||||
|
from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
|
||||||
|
from prometheus_client import start_http_server
|
||||||
|
|
||||||
|
OPENTELEMETRY_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning("OpenTelemetry not available. pip install [opentelemetry]")
|
||||||
|
OPENTELEMETRY_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryMetricsCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.
|
||||||
|
|
||||||
|
This callback automatically tracks key training metrics including:
|
||||||
|
- Training loss
|
||||||
|
- Evaluation loss
|
||||||
|
- Learning rate
|
||||||
|
- Epoch progress
|
||||||
|
- Global step count
|
||||||
|
- Gradient norm
|
||||||
|
|
||||||
|
Metrics are exposed via HTTP endpoint for Prometheus scraping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
LOG.warning("OpenTelemetry not available, metrics will not be collected")
|
||||||
|
self.metrics_enabled = False
|
||||||
|
return
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost")
|
||||||
|
self.metrics_port = getattr(cfg, "otel_metrics_port", 8000)
|
||||||
|
self.metrics_enabled = True
|
||||||
|
self.server_started = False
|
||||||
|
self.metrics_lock = threading.Lock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create Prometheus metrics reader
|
||||||
|
prometheus_reader = PrometheusMetricReader()
|
||||||
|
|
||||||
|
# Create meter provider with Prometheus exporter
|
||||||
|
provider = SDKMeterProvider(metric_readers=[prometheus_reader])
|
||||||
|
set_meter_provider(provider)
|
||||||
|
|
||||||
|
# Get meter for creating metrics
|
||||||
|
self.meter = metrics.get_meter("axolotl.training")
|
||||||
|
|
||||||
|
# Create metrics
|
||||||
|
self._create_metrics()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
|
||||||
|
self.metrics_enabled = False
|
||||||
|
|
||||||
|
def _create_metrics(self):
|
||||||
|
"""Create all metrics that will be tracked"""
|
||||||
|
self.train_loss_gauge = self.meter.create_gauge(
|
||||||
|
name="axolotl_train_loss",
|
||||||
|
description="Current training loss",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.eval_loss_gauge = self.meter.create_gauge(
|
||||||
|
name="axolotl_eval_loss",
|
||||||
|
description="Current evaluation loss",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.learning_rate_gauge = self.meter.create_gauge(
|
||||||
|
name="axolotl_learning_rate",
|
||||||
|
description="Current learning rate",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.epoch_gauge = self.meter.create_gauge(
|
||||||
|
name="axolotl_epoch",
|
||||||
|
description="Current training epoch",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.global_step_counter = self.meter.create_counter(
|
||||||
|
name="axolotl_global_steps",
|
||||||
|
description="Total training steps completed",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.grad_norm_gauge = self.meter.create_gauge(
|
||||||
|
name="axolotl_gradient_norm",
|
||||||
|
description="Gradient norm",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.memory_usage_gauge = self.meter.create_gauge(
|
||||||
|
name="axolotl_memory_usage",
|
||||||
|
description="Current memory usage in MB",
|
||||||
|
unit="MB",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _start_metrics_server(self):
|
||||||
|
"""Start the HTTP server for metrics exposure"""
|
||||||
|
if self.server_started:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_http_server(self.metrics_port, addr=self.metrics_host)
|
||||||
|
self.server_started = True
|
||||||
|
LOG.info(
|
||||||
|
f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called at the beginning of training"""
|
||||||
|
if not self.metrics_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._start_metrics_server()
|
||||||
|
LOG.info("OpenTelemetry metrics collection started")
|
||||||
|
|
||||||
|
def on_log(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
logs: Optional[Dict[str, float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called when logging occurs"""
|
||||||
|
if not self.metrics_enabled or not logs:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "loss" in logs:
|
||||||
|
self.train_loss_gauge.set(logs["loss"])
|
||||||
|
|
||||||
|
if "eval_loss" in logs:
|
||||||
|
self.eval_loss_gauge.set(logs["eval_loss"])
|
||||||
|
|
||||||
|
if "learning_rate" in logs:
|
||||||
|
self.learning_rate_gauge.set(logs["learning_rate"])
|
||||||
|
|
||||||
|
if "epoch" in logs:
|
||||||
|
self.epoch_gauge.set(logs["epoch"])
|
||||||
|
|
||||||
|
if "grad_norm" in logs:
|
||||||
|
self.grad_norm_gauge.set(logs["grad_norm"])
|
||||||
|
if "memory_usage" in logs:
|
||||||
|
self.memory_usage_gauge.set(logs["memory_usage"])
|
||||||
|
|
||||||
|
def on_step_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called at the end of each training step"""
|
||||||
|
if not self.metrics_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Update step counter and epoch
|
||||||
|
self.global_step_counter.add(1)
|
||||||
|
if state.epoch is not None:
|
||||||
|
self.epoch_gauge.set(state.epoch)
|
||||||
|
|
||||||
|
def on_evaluate(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
metrics: Optional[Dict[str, float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called after evaluation"""
|
||||||
|
if not self.metrics_enabled or not metrics:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "eval_loss" in metrics:
|
||||||
|
self.eval_loss_gauge.set(metrics["eval_loss"])
|
||||||
|
|
||||||
|
# Record any other eval metrics as gauges
|
||||||
|
for key, value in metrics.items():
|
||||||
|
if key.startswith("eval_") and isinstance(value, (int, float)):
|
||||||
|
# Create gauge for this metric if it doesn't exist
|
||||||
|
gauge_name = f"axolotl_{key}"
|
||||||
|
try:
|
||||||
|
gauge = self.meter.create_gauge(
|
||||||
|
name=gauge_name,
|
||||||
|
description=f"Evaluation metric: {key}",
|
||||||
|
unit="1",
|
||||||
|
)
|
||||||
|
gauge.set(value)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")
|
||||||
|
|
||||||
|
def on_train_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called at the end of training"""
|
||||||
|
if not self.metrics_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
LOG.info("Training completed. OpenTelemetry metrics collection finished.")
|
||||||
|
LOG.info(
|
||||||
|
f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
|
||||||
|
)
|
||||||
@@ -30,6 +30,7 @@ from axolotl.utils.schemas.integrations import (
|
|||||||
GradioConfig,
|
GradioConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
|
OpenTelemetryConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
)
|
)
|
||||||
@@ -60,6 +61,7 @@ class AxolotlInputConfig(
|
|||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
CometConfig,
|
CometConfig,
|
||||||
|
OpenTelemetryConfig,
|
||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
|
|||||||
@@ -176,3 +176,27 @@ class RayConfig(BaseModel):
|
|||||||
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTelemetryConfig(BaseModel):
|
||||||
|
"""OpenTelemetry configuration subset"""
|
||||||
|
|
||||||
|
use_otel_metrics: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable OpenTelemetry metrics collection and Prometheus export"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
otel_metrics_host: str | None = Field(
|
||||||
|
default="localhost",
|
||||||
|
json_schema_extra={
|
||||||
|
"title": "OpenTelemetry Metrics Host",
|
||||||
|
"description": "Host to bind the OpenTelemetry metrics server to",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
otel_metrics_port: int | None = Field(
|
||||||
|
default=8000,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Port for the Prometheus metrics HTTP server"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
349
tests/test_opentelemetry_callback.py
Normal file
349
tests/test_opentelemetry_callback.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
"""Tests for OpenTelemetry metrics callback functionality."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_otel_config():
|
||||||
|
"""Mock configuration for OpenTelemetry callback."""
|
||||||
|
return DictDefault(
|
||||||
|
{
|
||||||
|
"use_otel_metrics": True,
|
||||||
|
"otel_metrics_host": "localhost",
|
||||||
|
"otel_metrics_port": 8003, # Use unique port for tests
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_trainer_state():
|
||||||
|
"""Mock trainer state for callback testing."""
|
||||||
|
from transformers import TrainerState
|
||||||
|
|
||||||
|
state = TrainerState()
|
||||||
|
state.epoch = 1.0
|
||||||
|
state.global_step = 100
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_training_args():
|
||||||
|
"""Mock training arguments for callback testing."""
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
|
||||||
|
return TrainingArguments(output_dir="/tmp/test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_trainer_control():
|
||||||
|
"""Mock trainer control for callback testing."""
|
||||||
|
from transformers.trainer_callback import TrainerControl
|
||||||
|
|
||||||
|
return TrainerControl()
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryConfig:
|
||||||
|
"""Test OpenTelemetry configuration schema."""
|
||||||
|
|
||||||
|
def test_config_schema_valid(self):
|
||||||
|
"""Test OpenTelemetry configuration schema validation."""
|
||||||
|
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||||
|
|
||||||
|
# Test valid config
|
||||||
|
valid_config = {
|
||||||
|
"use_otel_metrics": True,
|
||||||
|
"otel_metrics_host": "localhost",
|
||||||
|
"otel_metrics_port": 8000,
|
||||||
|
}
|
||||||
|
|
||||||
|
otel_config = OpenTelemetryConfig(**valid_config)
|
||||||
|
assert otel_config.use_otel_metrics is True
|
||||||
|
assert otel_config.otel_metrics_host == "localhost"
|
||||||
|
assert otel_config.otel_metrics_port == 8000
|
||||||
|
|
||||||
|
def test_config_defaults(self):
|
||||||
|
"""Test OpenTelemetry configuration default values."""
|
||||||
|
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||||
|
|
||||||
|
# Test minimal config with defaults
|
||||||
|
minimal_config = {"use_otel_metrics": True}
|
||||||
|
|
||||||
|
otel_config = OpenTelemetryConfig(**minimal_config)
|
||||||
|
assert otel_config.use_otel_metrics is True
|
||||||
|
assert otel_config.otel_metrics_host == "localhost" # default
|
||||||
|
assert otel_config.otel_metrics_port == 8000 # default
|
||||||
|
|
||||||
|
def test_config_disabled_by_default(self):
|
||||||
|
"""Test that OpenTelemetry is disabled by default."""
|
||||||
|
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||||
|
|
||||||
|
# Test default config
|
||||||
|
default_config = OpenTelemetryConfig()
|
||||||
|
assert default_config.use_otel_metrics is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryCallback:
|
||||||
|
"""Test OpenTelemetry callback functionality."""
|
||||||
|
|
||||||
|
def test_callback_import(self):
|
||||||
|
"""Test that OpenTelemetry callback can be imported."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||||
|
|
||||||
|
assert OpenTelemetryMetricsCallback is not None
|
||||||
|
|
||||||
|
def test_callback_graceful_fallback(self, mock_otel_config):
|
||||||
|
"""Test callback gracefully handles missing dependencies."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||||
|
|
||||||
|
# This should not raise an exception even if dependencies are missing
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
|
||||||
|
# Callback should exist but may have metrics disabled
|
||||||
|
assert callback is not None
|
||||||
|
assert hasattr(callback, "metrics_enabled")
|
||||||
|
|
||||||
|
def test_callback_initialization_enabled(self, mock_otel_config):
|
||||||
|
"""Test callback initialization when OpenTelemetry is available."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
|
||||||
|
if OPENTELEMETRY_AVAILABLE:
|
||||||
|
assert callback.metrics_enabled is True
|
||||||
|
assert callback.cfg == mock_otel_config
|
||||||
|
assert callback.metrics_host == "localhost"
|
||||||
|
assert callback.metrics_port == 8003
|
||||||
|
else:
|
||||||
|
assert callback.metrics_enabled is False
|
||||||
|
|
||||||
|
def test_metrics_server_lifecycle(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test metrics server starts and stops correctly."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
assert callback.server_started is True
|
||||||
|
|
||||||
|
# End training
|
||||||
|
callback.on_train_end(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_metrics_recording(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test that metrics are recorded during training."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test logging metrics
|
||||||
|
test_logs = {
|
||||||
|
"loss": 0.5,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"grad_norm": 0.8,
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should not raise an exception
|
||||||
|
callback.on_log(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs
|
||||||
|
)
|
||||||
|
assert callback.metrics_enabled is True
|
||||||
|
|
||||||
|
def test_evaluation_metrics(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test evaluation metrics recording."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test evaluation metrics
|
||||||
|
eval_logs = {
|
||||||
|
"eval_loss": 0.3,
|
||||||
|
"eval_accuracy": 0.95,
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should not raise an exception
|
||||||
|
callback.on_evaluate(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs
|
||||||
|
)
|
||||||
|
assert callback.metrics_enabled is True
|
||||||
|
|
||||||
|
def test_thread_safety(self, mock_otel_config):
|
||||||
|
"""Test that callback has thread safety mechanisms."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
assert hasattr(callback, "metrics_lock")
|
||||||
|
# Check it's a lock-like object
|
||||||
|
assert hasattr(callback.metrics_lock, "__enter__")
|
||||||
|
assert hasattr(callback.metrics_lock, "__exit__")
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryIntegration:
|
||||||
|
"""Integration tests for OpenTelemetry."""
|
||||||
|
|
||||||
|
def test_availability_check(self):
|
||||||
|
"""Test availability check function."""
|
||||||
|
from axolotl.utils import is_opentelemetry_available
|
||||||
|
|
||||||
|
result = is_opentelemetry_available()
|
||||||
|
assert isinstance(result, bool)
|
||||||
|
|
||||||
|
def test_prometheus_endpoint_basic(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test basic Prometheus endpoint functionality."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("requests library not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
if not callback.server_started:
|
||||||
|
pytest.skip("Metrics server failed to start")
|
||||||
|
|
||||||
|
# Give server time to start
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Try to access metrics endpoint
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"http://{callback.metrics_host}:{callback.metrics_port}/metrics",
|
||||||
|
timeout=2,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Check for Prometheus format
|
||||||
|
assert "# TYPE" in response.text or "# HELP" in response.text
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
pytest.skip(
|
||||||
|
"Could not connect to metrics endpoint - this is expected in some environments"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenTelemetryCallbackMethods:
|
||||||
|
"""Test specific callback methods."""
|
||||||
|
|
||||||
|
def test_step_end_callback(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test step end callback method."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise an exception
|
||||||
|
callback.on_step_end(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_epoch_end_callback(
|
||||||
|
self,
|
||||||
|
mock_otel_config,
|
||||||
|
mock_trainer_state,
|
||||||
|
mock_training_args,
|
||||||
|
mock_trainer_control,
|
||||||
|
):
|
||||||
|
"""Test epoch end callback method."""
|
||||||
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
|
OPENTELEMETRY_AVAILABLE,
|
||||||
|
OpenTelemetryMetricsCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not OPENTELEMETRY_AVAILABLE:
|
||||||
|
pytest.skip("OpenTelemetry dependencies not available")
|
||||||
|
|
||||||
|
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||||
|
callback.on_train_begin(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise an exception
|
||||||
|
callback.on_epoch_end(
|
||||||
|
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user