Merge branch 'main' into telemetry-opt-in
This commit is contained in:
@@ -546,7 +546,6 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.skip("regression failure from v4.57.0")
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.dpo.chat_template import default
|
||||
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="argilla_chat_dataset")
|
||||
def fixture_argilla_chat_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"chosen": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "goodbye",
|
||||
},
|
||||
],
|
||||
"rejected": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "party on",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="phi3_tokenizer")
|
||||
@enable_hf_offline
|
||||
def fixture_phi3_tokenizer():
|
||||
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
|
||||
assert result["rejected"] == "party on<end_of_turn>"
|
||||
|
||||
|
||||
class TestArgillaChatDPOChatTemplate:
|
||||
"""
|
||||
Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
|
||||
"""
|
||||
|
||||
def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
|
||||
transform_fn, _ = argilla_chat(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template.argilla_chat",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<|begin_of_text|>"
|
||||
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
|
||||
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<|eot_id|>"
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
|
||||
transform_fn, _ = argilla_chat(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template.argilla_chat",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
|
||||
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
|
||||
assert result["chosen"] == "goodbye<|end|>"
|
||||
assert result["rejected"] == "party on<|end|>"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -80,16 +80,26 @@ class TestModelsUtils:
|
||||
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
||||
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
||||
)
|
||||
elif load_in_8bit and self.cfg.adapter is not None:
|
||||
assert self.model_loader.model_kwargs["load_in_8bit"]
|
||||
elif load_in_4bit and self.cfg.adapter is not None:
|
||||
assert self.model_loader.model_kwargs["load_in_4bit"]
|
||||
|
||||
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
|
||||
self.cfg.adapter == "lora" and load_in_8bit
|
||||
):
|
||||
assert self.model_loader.model_kwargs.get(
|
||||
"quantization_config", BitsAndBytesConfig
|
||||
if self.cfg.adapter == "qlora" and load_in_4bit:
|
||||
assert isinstance(
|
||||
self.model_loader.model_kwargs.get("quantization_config"),
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
assert (
|
||||
self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
|
||||
is True
|
||||
)
|
||||
if self.cfg.adapter == "lora" and load_in_8bit:
|
||||
assert isinstance(
|
||||
self.model_loader.model_kwargs.get("quantization_config"),
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
assert (
|
||||
self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
|
||||
is True
|
||||
)
|
||||
|
||||
def test_message_property_mapping(self):
|
||||
|
||||
349
tests/test_opentelemetry_callback.py
Normal file
349
tests/test_opentelemetry_callback.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Tests for OpenTelemetry metrics callback functionality."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_otel_config():
|
||||
"""Mock configuration for OpenTelemetry callback."""
|
||||
return DictDefault(
|
||||
{
|
||||
"use_otel_metrics": True,
|
||||
"otel_metrics_host": "localhost",
|
||||
"otel_metrics_port": 8003, # Use unique port for tests
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer_state():
|
||||
"""Mock trainer state for callback testing."""
|
||||
from transformers import TrainerState
|
||||
|
||||
state = TrainerState()
|
||||
state.epoch = 1.0
|
||||
state.global_step = 100
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_training_args():
|
||||
"""Mock training arguments for callback testing."""
|
||||
from transformers import TrainingArguments
|
||||
|
||||
return TrainingArguments(output_dir="/tmp/test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer_control():
|
||||
"""Mock trainer control for callback testing."""
|
||||
from transformers.trainer_callback import TrainerControl
|
||||
|
||||
return TrainerControl()
|
||||
|
||||
|
||||
class TestOpenTelemetryConfig:
|
||||
"""Test OpenTelemetry configuration schema."""
|
||||
|
||||
def test_config_schema_valid(self):
|
||||
"""Test OpenTelemetry configuration schema validation."""
|
||||
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||
|
||||
# Test valid config
|
||||
valid_config = {
|
||||
"use_otel_metrics": True,
|
||||
"otel_metrics_host": "localhost",
|
||||
"otel_metrics_port": 8000,
|
||||
}
|
||||
|
||||
otel_config = OpenTelemetryConfig(**valid_config)
|
||||
assert otel_config.use_otel_metrics is True
|
||||
assert otel_config.otel_metrics_host == "localhost"
|
||||
assert otel_config.otel_metrics_port == 8000
|
||||
|
||||
def test_config_defaults(self):
|
||||
"""Test OpenTelemetry configuration default values."""
|
||||
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||
|
||||
# Test minimal config with defaults
|
||||
minimal_config = {"use_otel_metrics": True}
|
||||
|
||||
otel_config = OpenTelemetryConfig(**minimal_config)
|
||||
assert otel_config.use_otel_metrics is True
|
||||
assert otel_config.otel_metrics_host == "localhost" # default
|
||||
assert otel_config.otel_metrics_port == 8000 # default
|
||||
|
||||
def test_config_disabled_by_default(self):
|
||||
"""Test that OpenTelemetry is disabled by default."""
|
||||
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||
|
||||
# Test default config
|
||||
default_config = OpenTelemetryConfig()
|
||||
assert default_config.use_otel_metrics is False
|
||||
|
||||
|
||||
class TestOpenTelemetryCallback:
|
||||
"""Test OpenTelemetry callback functionality."""
|
||||
|
||||
def test_callback_import(self):
|
||||
"""Test that OpenTelemetry callback can be imported."""
|
||||
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||
|
||||
assert OpenTelemetryMetricsCallback is not None
|
||||
|
||||
def test_callback_graceful_fallback(self, mock_otel_config):
|
||||
"""Test callback gracefully handles missing dependencies."""
|
||||
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||
|
||||
# This should not raise an exception even if dependencies are missing
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
|
||||
# Callback should exist but may have metrics disabled
|
||||
assert callback is not None
|
||||
assert hasattr(callback, "metrics_enabled")
|
||||
|
||||
def test_callback_initialization_enabled(self, mock_otel_config):
|
||||
"""Test callback initialization when OpenTelemetry is available."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
|
||||
if OPENTELEMETRY_AVAILABLE:
|
||||
assert callback.metrics_enabled is True
|
||||
assert callback.cfg == mock_otel_config
|
||||
assert callback.metrics_host == "localhost"
|
||||
assert callback.metrics_port == 8003
|
||||
else:
|
||||
assert callback.metrics_enabled is False
|
||||
|
||||
def test_metrics_server_lifecycle(
|
||||
self,
|
||||
mock_otel_config,
|
||||
mock_trainer_state,
|
||||
mock_training_args,
|
||||
mock_trainer_control,
|
||||
):
|
||||
"""Test metrics server starts and stops correctly."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
|
||||
# Start server
|
||||
callback.on_train_begin(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
assert callback.server_started is True
|
||||
|
||||
# End training
|
||||
callback.on_train_end(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
def test_metrics_recording(
|
||||
self,
|
||||
mock_otel_config,
|
||||
mock_trainer_state,
|
||||
mock_training_args,
|
||||
mock_trainer_control,
|
||||
):
|
||||
"""Test that metrics are recorded during training."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
callback.on_train_begin(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
# Test logging metrics
|
||||
test_logs = {
|
||||
"loss": 0.5,
|
||||
"learning_rate": 1e-4,
|
||||
"grad_norm": 0.8,
|
||||
}
|
||||
|
||||
# This should not raise an exception
|
||||
callback.on_log(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs
|
||||
)
|
||||
assert callback.metrics_enabled is True
|
||||
|
||||
def test_evaluation_metrics(
|
||||
self,
|
||||
mock_otel_config,
|
||||
mock_trainer_state,
|
||||
mock_training_args,
|
||||
mock_trainer_control,
|
||||
):
|
||||
"""Test evaluation metrics recording."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
callback.on_train_begin(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
# Test evaluation metrics
|
||||
eval_logs = {
|
||||
"eval_loss": 0.3,
|
||||
"eval_accuracy": 0.95,
|
||||
}
|
||||
|
||||
# This should not raise an exception
|
||||
callback.on_evaluate(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs
|
||||
)
|
||||
assert callback.metrics_enabled is True
|
||||
|
||||
def test_thread_safety(self, mock_otel_config):
|
||||
"""Test that callback has thread safety mechanisms."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
assert hasattr(callback, "metrics_lock")
|
||||
# Check it's a lock-like object
|
||||
assert hasattr(callback.metrics_lock, "__enter__")
|
||||
assert hasattr(callback.metrics_lock, "__exit__")
|
||||
|
||||
|
||||
class TestOpenTelemetryIntegration:
|
||||
"""Integration tests for OpenTelemetry."""
|
||||
|
||||
def test_availability_check(self):
|
||||
"""Test availability check function."""
|
||||
from axolotl.utils import is_opentelemetry_available
|
||||
|
||||
result = is_opentelemetry_available()
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_prometheus_endpoint_basic(
|
||||
self,
|
||||
mock_otel_config,
|
||||
mock_trainer_state,
|
||||
mock_training_args,
|
||||
mock_trainer_control,
|
||||
):
|
||||
"""Test basic Prometheus endpoint functionality."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
pytest.skip("requests library not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
callback.on_train_begin(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
if not callback.server_started:
|
||||
pytest.skip("Metrics server failed to start")
|
||||
|
||||
# Give server time to start
|
||||
time.sleep(1)
|
||||
|
||||
# Try to access metrics endpoint
|
||||
try:
|
||||
response = requests.get(
|
||||
f"http://{callback.metrics_host}:{callback.metrics_port}/metrics",
|
||||
timeout=2,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
# Check for Prometheus format
|
||||
assert "# TYPE" in response.text or "# HELP" in response.text
|
||||
except requests.exceptions.RequestException:
|
||||
pytest.skip(
|
||||
"Could not connect to metrics endpoint - this is expected in some environments"
|
||||
)
|
||||
|
||||
|
||||
class TestOpenTelemetryCallbackMethods:
|
||||
"""Test specific callback methods."""
|
||||
|
||||
def test_step_end_callback(
|
||||
self,
|
||||
mock_otel_config,
|
||||
mock_trainer_state,
|
||||
mock_training_args,
|
||||
mock_trainer_control,
|
||||
):
|
||||
"""Test step end callback method."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
callback.on_train_begin(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
# Should not raise an exception
|
||||
callback.on_step_end(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
def test_epoch_end_callback(
|
||||
self,
|
||||
mock_otel_config,
|
||||
mock_trainer_state,
|
||||
mock_training_args,
|
||||
mock_trainer_control,
|
||||
):
|
||||
"""Test epoch end callback method."""
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OPENTELEMETRY_AVAILABLE,
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
if not OPENTELEMETRY_AVAILABLE:
|
||||
pytest.skip("OpenTelemetry dependencies not available")
|
||||
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
callback.on_train_begin(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
|
||||
# Should not raise an exception
|
||||
callback.on_epoch_end(
|
||||
mock_training_args, mock_trainer_state, mock_trainer_control
|
||||
)
|
||||
Reference in New Issue
Block a user