Files
axolotl/tests/integrations/test_swanlab.py
Wing Lian e412370877 roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions

* remove typo

* update qwen3.5 moe 35b-a3b yaml for 5090

* more bug fixes

* fix tests to match updated trainer

* don't use fa2 for hooks test

* reset plugins on the instance

* retry download

* fix references to renamed axolotl_cfg property on trainer

* Fix ref to trainer cfg
2026-03-24 15:40:05 -04:00

1338 lines
49 KiB
Python

# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for SwanLab Integration Plugin.
Tests conflict detection, configuration validation, multi-logger warnings,
Lark (Feishu) notification registration, and RLHF completion logging.
"""
import importlib.util
import logging
import os
import time
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
# True when the optional ``swanlab`` package can be located without importing it;
# every test class below is skipped when it cannot.
SWANLAB_INSTALLED = bool(importlib.util.find_spec("swanlab"))
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabConfigValidators:
    """Exercise the Pydantic validators declared on SwanLabConfig."""

    @staticmethod
    def _assert_mode_accepted(mode):
        """Build a config with ``mode`` and check the value is stored unchanged."""
        assert SwanLabConfig(swanlab_mode=mode).swanlab_mode == mode

    @staticmethod
    def _rejection_message(**fields):
        """Return the text of the ValidationError raised when building ``fields``."""
        with pytest.raises(ValidationError) as excinfo:
            SwanLabConfig(**fields)
        return str(excinfo.value)

    def test_valid_swanlab_mode_cloud(self):
        """'cloud' is an accepted mode."""
        self._assert_mode_accepted("cloud")

    def test_valid_swanlab_mode_local(self):
        """'local' is an accepted mode."""
        self._assert_mode_accepted("local")

    def test_valid_swanlab_mode_offline(self):
        """'offline' is an accepted mode."""
        self._assert_mode_accepted("offline")

    def test_valid_swanlab_mode_disabled(self):
        """'disabled' is an accepted mode."""
        self._assert_mode_accepted("disabled")

    def test_invalid_swanlab_mode(self):
        """An unknown mode is rejected and the message lists every valid choice."""
        message = self._rejection_message(swanlab_mode="invalid")
        assert "Invalid swanlab_mode" in message
        for choice in ("cloud", "local", "offline", "disabled"):
            assert choice in message

    def test_swanlab_mode_none_allowed(self):
        """Leaving the mode unset (None) is permitted."""
        assert SwanLabConfig(swanlab_mode=None).swanlab_mode is None

    def test_valid_swanlab_project(self):
        """A non-empty project name is kept as given."""
        assert SwanLabConfig(swanlab_project="my-project").swanlab_project == "my-project"

    def test_swanlab_project_none_allowed(self):
        """Leaving the project unset (None) is permitted."""
        assert SwanLabConfig(swanlab_project=None).swanlab_project is None

    def test_empty_swanlab_project_rejected(self):
        """An empty project name fails validation."""
        message = self._rejection_message(swanlab_project="")
        assert "cannot be an empty string" in message

    def test_whitespace_only_project_rejected(self):
        """A whitespace-only project name fails validation."""
        message = self._rejection_message(swanlab_project=" ")
        assert "cannot be an empty string" in message

    def test_use_swanlab_true_requires_project(self):
        """Enabling SwanLab without a project name fails validation."""
        lowered = self._rejection_message(
            use_swanlab=True, swanlab_project=None
        ).lower()
        assert "swanlab_project" in lowered
        assert "not set" in lowered

    def test_use_swanlab_true_with_project_valid(self):
        """Enabling SwanLab together with a project name is accepted."""
        cfg = SwanLabConfig(use_swanlab=True, swanlab_project="my-project")
        assert cfg.use_swanlab is True
        assert cfg.swanlab_project == "my-project"

    def test_use_swanlab_false_no_project_valid(self):
        """SwanLab explicitly disabled needs no project name."""
        cfg = SwanLabConfig(use_swanlab=False, swanlab_project=None)
        assert cfg.use_swanlab is False
        assert cfg.swanlab_project is None

    def test_use_swanlab_none_no_project_valid(self):
        """SwanLab left unset needs no project name."""
        cfg = SwanLabConfig(use_swanlab=None, swanlab_project=None)
        assert cfg.use_swanlab is None
        assert cfg.swanlab_project is None
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginRegister:
    """Validation performed by ``SwanLabPlugin.register()`` on the raw config dict."""

    def test_register_without_use_swanlab(self):
        """Registering a config with SwanLab switched off must not raise."""
        SwanLabPlugin().register({"use_swanlab": False})

    def test_register_use_swanlab_missing_project(self):
        """Enabling SwanLab without a project fails with an actionable message."""
        plugin = SwanLabPlugin()
        with pytest.raises(ValueError) as excinfo:
            plugin.register({"use_swanlab": True})
        message = str(excinfo.value)
        for fragment in ("swanlab_project", "not set", "Solutions"):
            assert fragment in message

    def test_register_use_swanlab_with_project_valid(self):
        """Enabling SwanLab together with a project name is accepted."""
        SwanLabPlugin().register({"use_swanlab": True, "swanlab_project": "my-project"})

    def test_register_invalid_mode(self):
        """An unknown swanlab_mode is rejected and valid modes are suggested."""
        plugin = SwanLabPlugin()
        bad_cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "swanlab_mode": "invalid-mode",
        }
        with pytest.raises(ValueError) as excinfo:
            plugin.register(bad_cfg)
        message = str(excinfo.value)
        for fragment in ("Invalid swanlab_mode", "cloud", "local"):
            assert fragment in message

    def test_register_valid_modes(self):
        """Every supported swanlab_mode passes registration on one plugin instance."""
        plugin = SwanLabPlugin()
        for mode in ("cloud", "local", "offline", "disabled"):
            plugin.register(
                {
                    "use_swanlab": True,
                    "swanlab_project": "my-project",
                    "swanlab_mode": mode,
                }
            )

    def test_register_auto_enable_swanlab(self):
        """Supplying only swanlab_project sets use_swanlab=True in the same dict."""
        cfg = {"swanlab_project": "my-project"}
        SwanLabPlugin().register(cfg)
        assert cfg["use_swanlab"] is True

    def test_register_cloud_mode_without_api_key_warns(self, caplog):
        """Cloud mode with no API key available in the environment logs a warning."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "swanlab_mode": "cloud",
        }
        # Empty the environment so no SWANLAB_API_KEY can be picked up.
        with (
            patch.dict(os.environ, {}, clear=True),
            caplog.at_level(logging.WARNING),
        ):
            plugin.register(cfg)
        assert any("API key" in record.message for record in caplog.records)
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestMultiLoggerDetection:
    """Warnings emitted when SwanLab is enabled alongside other experiment trackers."""

    @staticmethod
    def _records_after_register(caplog, level, **extra_cfg):
        """Register a SwanLab config merged with ``extra_cfg``; return captured records."""
        plugin = SwanLabPlugin()
        cfg = {"use_swanlab": True, "swanlab_project": "my-project", **extra_cfg}
        with caplog.at_level(level):
            plugin.register(cfg)
        return caplog.records

    def test_single_logger_no_warning(self, caplog):
        """SwanLab on its own does not trigger the multi-logger warning."""
        records = self._records_after_register(caplog, logging.WARNING)
        assert all("Multiple logging tools" not in rec.message for rec in records)

    def test_two_loggers_warning(self, caplog):
        """SwanLab plus WandB triggers a warning naming both tools."""
        records = self._records_after_register(
            caplog, logging.WARNING, use_wandb=True
        )
        texts = [rec.message for rec in records]
        assert any("Multiple logging tools" in text for text in texts)
        assert any("SwanLab" in text and "WandB" in text for text in texts)

    def test_three_loggers_error(self, caplog):
        """Three concurrent trackers escalate the message to ERROR level."""
        records = self._records_after_register(
            caplog, logging.ERROR, use_wandb=True, use_mlflow=True
        )
        errors = [rec.message for rec in records if rec.levelno >= logging.ERROR]
        assert any("logging tools enabled" in text for text in errors)

    def test_multi_logger_with_comet(self, caplog):
        """Comet is recognised through comet_api_key."""
        records = self._records_after_register(
            caplog, logging.WARNING, comet_api_key="test-key"
        )
        assert any("Comet" in rec.message for rec in records)

    def test_multi_logger_with_comet_project(self, caplog):
        """Comet is recognised through comet_project_name."""
        records = self._records_after_register(
            caplog, logging.WARNING, comet_project_name="test-project"
        )
        assert any("Comet" in rec.message for rec in records)
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginPreModelLoad:
    """Tests for SwanLabPlugin.pre_model_load() runtime checks."""

    def test_pre_model_load_disabled(self):
        """pre_model_load() returns quietly when SwanLab is disabled."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = False
        # Should not raise
        plugin.pre_model_load(cfg)

    def test_pre_model_load_import_error(self):
        """A missing swanlab package surfaces as a clear, actionable ImportError."""
        import sys

        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        # A ``None`` entry in sys.modules makes ``import swanlab`` raise
        # ModuleNotFoundError (an ImportError subclass) while every other import
        # keeps working. Patching ``builtins.__import__`` instead would make *all*
        # imports fail, including unrelated ones inside the plugin and pytest.
        with patch.dict(sys.modules, {"swanlab": None}):
            with pytest.raises(ImportError) as exc_info:
                plugin.pre_model_load(cfg)
        error_msg = str(exc_info.value)
        assert "SwanLab is not installed" in error_msg
        assert "pip install swanlab" in error_msg

    @patch("axolotl.utils.distributed.is_main_process")
    @patch("axolotl.utils.distributed.get_world_size")
    def test_pre_model_load_non_main_process_skips(
        self, mock_get_world_size, mock_is_main_process
    ):
        """Non-main ranks must not initialise a SwanLab run."""
        mock_get_world_size.return_value = 2
        mock_is_main_process.return_value = False
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        with patch("swanlab.init") as mock_init:
            plugin.pre_model_load(cfg)
        # Should NOT call swanlab.init
        mock_init.assert_not_called()

    @patch("axolotl.utils.distributed.is_main_process")
    @patch("axolotl.utils.distributed.get_world_size")
    def test_pre_model_load_distributed_logging(
        self, mock_get_world_size, mock_is_main_process, caplog
    ):
        """Distributed runs log the world size and that only rank 0 logs."""
        mock_get_world_size.return_value = 4
        mock_is_main_process.return_value = True
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "cloud"
        # Unset MagicMock attributes are truthy mocks; pin the webhook to None so
        # this test does not wander into the Lark registration path.
        cfg.swanlab_lark_webhook_url = None
        with patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"):
            with caplog.at_level(logging.INFO):
                plugin.pre_model_load(cfg)
        # Should log distributed training info
        info_messages = [record.message for record in caplog.records]
        assert any("world_size=4" in msg for msg in info_messages)
        assert any("Only rank 0" in msg for msg in info_messages)
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabInitKwargs:
    """Tests for SwanLab initialization with direct parameter passing."""

    def test_custom_branding_added_to_config(self):
        """Axolotl branding is injected into the config passed to SwanLab."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "test-project",
            }
        )
        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)
        # Verify custom branding is present
        assert "config" in init_kwargs
        assert init_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl"

    def test_api_key_passed_directly(self):
        """The API key is passed to swanlab.init() as a kwarg, not via env var."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "test-project",
                "swanlab_api_key": "test-api-key-12345",
            }
        )
        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)
        assert "api_key" in init_kwargs
        assert init_kwargs["api_key"] == "test-api-key-12345"

    def test_private_deployment_hosts_passed_directly(self):
        """Private-deployment hosts are passed to swanlab.init() as kwargs."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "internal-project",
                "swanlab_web_host": "https://swanlab.company.com",
                "swanlab_api_host": "https://api-swanlab.company.com",
            }
        )
        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)
        assert "web_host" in init_kwargs
        assert init_kwargs["web_host"] == "https://swanlab.company.com"
        assert "api_host" in init_kwargs
        assert init_kwargs["api_host"] == "https://api-swanlab.company.com"

    @patch("axolotl.utils.distributed.is_main_process")
    def test_full_private_deployment_init(self, mock_is_main_process):
        """pre_model_load() forwards a full private-deployment config to swanlab.init()."""
        mock_is_main_process.return_value = True
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "secure-project",
                "swanlab_experiment_name": "experiment-001",
                "swanlab_mode": "cloud",
                "swanlab_api_key": "private-key-xyz",
                "swanlab_web_host": "https://swanlab.internal.net",
                "swanlab_api_host": "https://api.swanlab.internal.net",
                "swanlab_workspace": "research-team",
            }
        )
        with patch("swanlab.init") as mock_init:
            plugin.pre_model_load(cfg)
        mock_init.assert_called_once()
        call_kwargs = mock_init.call_args[1]
        assert call_kwargs["project"] == "secure-project"
        assert call_kwargs["experiment_name"] == "experiment-001"
        assert call_kwargs["mode"] == "cloud"
        assert call_kwargs["api_key"] == "private-key-xyz"
        assert call_kwargs["web_host"] == "https://swanlab.internal.net"
        assert call_kwargs["api_host"] == "https://api.swanlab.internal.net"
        assert call_kwargs["workspace"] == "research-team"
        assert call_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl"

    def test_env_vars_not_set_for_api_params(self):
        """API parameters go to swanlab.init() directly, never into os.environ."""
        from axolotl.utils.dict import DictDefault

        cleared_keys = (
            "SWANLAB_API_KEY",
            "SWANLAB_WEB_HOST",
            "SWANLAB_API_HOST",
            "SWANLAB_MODE",
        )
        api_env_keys = ("SWANLAB_API_KEY", "SWANLAB_WEB_HOST", "SWANLAB_API_HOST")
        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "test-project",
                "swanlab_api_key": "test-key",
                "swanlab_web_host": "https://test.com",
                "swanlab_api_host": "https://api-test.com",
                "swanlab_mode": "cloud",
            }
        )
        # patch.dict snapshots os.environ and restores it on exit, so clearing
        # the keys here cannot leak into other tests in the session.
        with patch.dict(os.environ):
            for key in cleared_keys:
                os.environ.pop(key, None)
            with (
                patch("axolotl.utils.distributed.is_main_process", return_value=True),
                patch("swanlab.init"),
            ):
                plugin.pre_model_load(cfg)
            # swanlab.init is mocked, so anything found here was exported by the
            # plugin itself (e.g. a reintroduced _setup_swanlab_env()).
            leaked = [key for key in api_env_keys if key in os.environ]
        assert not leaked, f"plugin exported SwanLab API settings as env vars: {leaked}"
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestLarkNotificationIntegration:
    """Tests for Lark (Feishu) notification integration."""

    _WEBHOOK_URL = "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook"

    @staticmethod
    def _make_cfg(webhook_url, secret):
        """Build a MagicMock cfg for a local-mode SwanLab run with Lark settings."""
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "local"
        cfg.swanlab_lark_webhook_url = webhook_url
        cfg.swanlab_lark_secret = secret
        return cfg

    def test_lark_callback_registration_with_webhook_only(self):
        """A webhook without a secret registers a LarkCallback with secret=None."""
        plugin = SwanLabPlugin()
        cfg = self._make_cfg(self._WEBHOOK_URL, None)
        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
            patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback,
        ):
            mock_lark_instance = MagicMock()
            MockLarkCallback.return_value = mock_lark_instance
            plugin.pre_model_load(cfg)
        MockLarkCallback.assert_called_once_with(
            webhook_url=self._WEBHOOK_URL,
            secret=None,
        )
        mock_register.assert_called_once_with([mock_lark_instance])

    def test_lark_callback_registration_with_secret(self):
        """A webhook with an HMAC secret forwards the secret to LarkCallback."""
        plugin = SwanLabPlugin()
        cfg = self._make_cfg(self._WEBHOOK_URL, "test-hmac-secret")
        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
            patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback,
        ):
            mock_lark_instance = MagicMock()
            MockLarkCallback.return_value = mock_lark_instance
            plugin.pre_model_load(cfg)
        MockLarkCallback.assert_called_once_with(
            webhook_url=self._WEBHOOK_URL,
            secret="test-hmac-secret",
        )
        mock_register.assert_called_once_with([mock_lark_instance])

    def test_lark_callback_not_registered_without_webhook(self):
        """No webhook URL means no callbacks are registered."""
        plugin = SwanLabPlugin()
        cfg = self._make_cfg(None, None)
        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            plugin.pre_model_load(cfg)
        mock_register.assert_not_called()

    def test_lark_import_error_handled_gracefully(self, caplog):
        """A SwanLab build without the notification plugin only logs a warning."""
        import sys

        plugin = SwanLabPlugin()
        cfg = self._make_cfg(self._WEBHOOK_URL, None)
        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
            # A None entry in sys.modules makes importing
            # swanlab.plugin.notification raise ImportError, which simulates an
            # older SwanLab. Giving the LarkCallback constructor an ImportError
            # side effect would only exercise a failing *call*, not a failing import.
            patch.dict(sys.modules, {"swanlab.plugin.notification": None}),
            caplog.at_level(logging.WARNING),
        ):
            plugin.pre_model_load(cfg)
        warning_messages = [record.message for record in caplog.records]
        assert any(
            "Failed to import SwanLab Lark plugin" in msg for msg in warning_messages
        )
        assert any("SwanLab >= 0.3.0" in msg for msg in warning_messages)

    def test_lark_warning_for_missing_secret(self, caplog):
        """A webhook without an HMAC secret logs a security warning."""
        plugin = SwanLabPlugin()
        cfg = self._make_cfg(self._WEBHOOK_URL, None)
        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks"),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
            patch("swanlab.plugin.notification.LarkCallback"),
            caplog.at_level(logging.WARNING),
        ):
            plugin.pre_model_load(cfg)
        warning_messages = [record.message for record in caplog.records]
        assert any("no secret configured" in msg.lower() for msg in warning_messages)
        assert any("swanlab_lark_secret" in msg for msg in warning_messages)
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginIntegration:
    """Integration tests for the SwanLab plugin lifecycle (register -> pre_model_load)."""

    def test_full_lifecycle_valid_config(self):
        """A valid config registers and then initialises exactly one SwanLab run."""
        plugin = SwanLabPlugin()
        cfg_dict = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_mode": "local",
        }
        plugin.register(cfg_dict)

        cfg_obj = MagicMock()
        cfg_obj.use_swanlab = True
        cfg_obj.swanlab_project = "test-project"
        cfg_obj.swanlab_mode = "local"
        cfg_obj.swanlab_lark_webhook_url = None  # No Lark
        with (
            patch("swanlab.init") as mock_init,
            patch("swanlab.__version__", "0.3.0"),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            plugin.pre_model_load(cfg_obj)
        mock_init.assert_called_once()

    def test_lifecycle_with_multi_logger_warning(self, caplog):
        """Registering alongside WandB emits the multi-logger warning."""
        plugin = SwanLabPlugin()
        cfg_dict = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "use_wandb": True,
        }
        with caplog.at_level(logging.WARNING):
            plugin.register(cfg_dict)
        warning_messages = [record.message for record in caplog.records]
        assert any("Multiple logging tools" in msg for msg in warning_messages)

    def test_lifecycle_invalid_config_fails_early(self):
        """A config missing swanlab_project fails at register(), before model load."""
        plugin = SwanLabPlugin()
        cfg_dict = {
            "use_swanlab": True,
            # Missing swanlab_project
        }
        with pytest.raises(ValueError):
            plugin.register(cfg_dict)

    def test_full_lifecycle_with_lark_notifications(self):
        """The full lifecycle initialises SwanLab AND registers the Lark callback."""
        plugin = SwanLabPlugin()
        cfg_dict = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_mode": "cloud",
        }
        plugin.register(cfg_dict)

        webhook_url = "https://open.feishu.cn/open-apis/bot/v2/hook/test"
        cfg_obj = MagicMock()
        cfg_obj.use_swanlab = True
        cfg_obj.swanlab_project = "test-project"
        cfg_obj.swanlab_mode = "cloud"
        cfg_obj.swanlab_lark_webhook_url = webhook_url
        cfg_obj.swanlab_lark_secret = "secret123"
        with (
            patch("swanlab.init") as mock_init,
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
            patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback,
        ):
            mock_lark_instance = MagicMock()
            MockLarkCallback.return_value = mock_lark_instance
            plugin.pre_model_load(cfg_obj)
        # Verify both SwanLab init AND Lark callback registration (the init
        # check was previously promised here but never asserted).
        mock_init.assert_called_once()
        MockLarkCallback.assert_called_once_with(
            webhook_url=webhook_url,
            secret="secret123",
        )
        mock_register.assert_called_once_with([mock_lark_instance])
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestCompletionLogger:
    """Behaviour of the bounded CompletionLogger buffer and its SwanLab export."""

    @staticmethod
    def _new_buffer(maxlen):
        """Import CompletionLogger lazily and return a fresh instance."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        return CompletionLogger(maxlen=maxlen)

    def test_completion_logger_initialization(self):
        """A new logger reports its capacity and starts empty."""
        buffer = self._new_buffer(64)
        assert buffer.maxlen == 64
        assert len(buffer) == 0

    def test_add_dpo_completion(self):
        """A DPO entry stores prompt, both responses and the reward margin."""
        buffer = self._new_buffer(10)
        buffer.add_dpo_completion(
            step=0,
            prompt="What is AI?",
            chosen="Artificial Intelligence is...",
            rejected="AI means...",
            reward_diff=0.5,
        )
        assert len(buffer) == 1
        row = buffer.data[0]
        assert row["step"] == 0
        assert row["prompt"] == "What is AI?"
        assert row["chosen"] == "Artificial Intelligence is..."
        assert row["rejected"] == "AI means..."
        assert row["reward_diff"] == 0.5

    def test_add_kto_completion(self):
        """A KTO entry maps label=True to 'desirable'."""
        buffer = self._new_buffer(10)
        buffer.add_kto_completion(
            step=1,
            prompt="Explain quantum physics",
            completion="Quantum physics is...",
            label=True,
            reward=0.8,
        )
        assert len(buffer) == 1
        row = buffer.data[0]
        assert row["step"] == 1
        assert row["prompt"] == "Explain quantum physics"
        assert row["completion"] == "Quantum physics is..."
        assert row["label"] == "desirable"
        assert row["reward"] == 0.8

    def test_add_orpo_completion(self):
        """An ORPO entry keeps both responses and the log-odds ratio."""
        buffer = self._new_buffer(10)
        buffer.add_orpo_completion(
            step=2,
            prompt="Write a poem",
            chosen="Roses are red...",
            rejected="Violets are blue...",
            log_odds_ratio=1.2,
        )
        assert len(buffer) == 1
        row = buffer.data[0]
        assert row["step"] == 2
        assert row["chosen"] == "Roses are red..."
        assert row["rejected"] == "Violets are blue..."
        assert row["log_odds_ratio"] == 1.2

    def test_add_grpo_completion(self):
        """A GRPO entry keeps the completion, reward and advantage."""
        buffer = self._new_buffer(10)
        buffer.add_grpo_completion(
            step=3,
            prompt="Solve this problem",
            completion="The answer is 42",
            reward=0.9,
            advantage=0.3,
        )
        assert len(buffer) == 1
        row = buffer.data[0]
        assert row["step"] == 3
        assert row["completion"] == "The answer is 42"
        assert row["reward"] == 0.9
        assert row["advantage"] == 0.3

    def test_memory_bounded_buffer(self):
        """Overflowing the buffer evicts the oldest entries first."""
        buffer = self._new_buffer(3)
        for step in range(5):
            buffer.add_dpo_completion(
                step=step,
                prompt=f"Prompt {step}",
                chosen=f"Chosen {step}",
                rejected=f"Rejected {step}",
            )
        assert len(buffer) == 3
        # Oldest kept first, newest last.
        assert [row["step"] for row in buffer.data] == [2, 3, 4]

    def test_log_to_swanlab_when_not_initialized(self):
        """Without an active SwanLab run the export reports failure instead of raising."""
        buffer = self._new_buffer(10)
        buffer.add_dpo_completion(step=0, prompt="Test", chosen="A", rejected="B")
        with patch("swanlab.get_run", return_value=None):
            outcome = buffer.log_to_swanlab()
        assert outcome is False

    def test_log_to_swanlab_success(self):
        """With an active run the buffer is written to a table and logged once."""
        buffer = self._new_buffer(10)
        buffer.add_dpo_completion(
            step=0,
            prompt="Test prompt",
            chosen="Chosen response",
            rejected="Rejected response",
            reward_diff=0.5,
        )
        with (
            patch("swanlab.get_run") as get_run_mock,
            patch("swanlab.log") as log_mock,
            patch("swanlab.echarts.Table") as table_cls_mock,
        ):
            get_run_mock.return_value = MagicMock()  # pretend a run is active
            table_mock = MagicMock()
            table_cls_mock.return_value = table_mock
            outcome = buffer.log_to_swanlab(table_name="test_table")
        assert outcome is True
        log_mock.assert_called_once()
        table_mock.add.assert_called_once()

    def test_clear_buffer(self):
        """clear() empties the buffer."""
        buffer = self._new_buffer(10)
        buffer.add_dpo_completion(step=0, prompt="Test", chosen="A", rejected="B")
        assert len(buffer) == 1
        buffer.clear()
        assert len(buffer) == 0

    def test_repr(self):
        """repr() shows the class name, capacity and current fill level."""
        buffer = self._new_buffer(128)
        buffer.add_dpo_completion(step=0, prompt="Test", chosen="A", rejected="B")
        text = repr(buffer)
        for fragment in ("CompletionLogger", "maxlen=128", "buffered=1/128"):
            assert fragment in text
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabRLHFCompletionCallback:
    """Behaviour of the trainer callback that buffers and exports RLHF completions."""

    @staticmethod
    def _callback_cls():
        """Import SwanLabRLHFCompletionCallback lazily."""
        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback

        return SwanLabRLHFCompletionCallback

    def _detected_type(self, trainer_class_name):
        """Run on_init_end with a fake trainer class name and return the inferred type."""
        callback = self._callback_cls()()
        fake_trainer = MagicMock()
        fake_trainer.__class__.__name__ = trainer_class_name
        callback.on_init_end(
            args=MagicMock(),
            state=MagicMock(),
            control=MagicMock(),
            trainer=fake_trainer,
        )
        return callback.trainer_type

    def test_callback_initialization(self):
        """Constructor arguments are stored and no trainer type is assumed yet."""
        callback = self._callback_cls()(
            log_interval=50,
            max_completions=64,
            table_name="custom_table",
        )
        assert callback.log_interval == 50
        assert callback.logger.maxlen == 64
        assert callback.table_name == "custom_table"
        assert callback.trainer_type is None

    def test_trainer_type_detection_dpo(self):
        """A trainer class named like a DPO trainer is classified as 'dpo'."""
        assert self._detected_type("AxolotlDPOTrainer") == "dpo"

    def test_trainer_type_detection_kto(self):
        """A trainer class named like a KTO trainer is classified as 'kto'."""
        assert self._detected_type("AxolotlKTOTrainer") == "kto"

    def test_on_train_end_logs_completions(self):
        """Whatever is still buffered when training ends is flushed to SwanLab."""
        callback = self._callback_cls()()
        callback.trainer_type = "dpo"
        callback.logger.add_dpo_completion(
            step=0,
            prompt="Test",
            chosen="A",
            rejected="B",
        )
        with patch.object(callback.logger, "log_to_swanlab") as flush_mock:
            callback.on_train_end(
                args=MagicMock(),
                state=MagicMock(global_step=100),
                control=MagicMock(),
            )
        flush_mock.assert_called_once()
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginCompletionIntegration:
    """Integration tests for completion logging in SwanLabPlugin.

    Each test calls ``SwanLabPlugin.post_trainer_create`` with a mocked trainer
    and checks whether the RLHF completion callback is attached through
    ``trainer.add_callback``.
    """

    @staticmethod
    def _make_plugin():
        """Return a SwanLabPlugin flagged as if SwanLab were already initialized."""
        plugin = SwanLabPlugin()
        plugin.swanlab_initialized = True  # Simulate SwanLab initialized
        return plugin

    @staticmethod
    def _make_cfg(**overrides):
        """Build a DictDefault config with SwanLab enabled, updated by ``overrides``."""
        from axolotl.utils.dict import DictDefault

        return DictDefault(
            {"use_swanlab": True, "swanlab_project": "test-project", **overrides}
        )

    @staticmethod
    def _make_trainer(class_name, **args_attrs):
        """Return a MagicMock trainer whose class is renamed to ``class_name``.

        ``args_attrs`` become attributes of ``trainer.args``; with none given,
        ``trainer.args`` is a plain ``MagicMock()``.
        """
        trainer = MagicMock()
        trainer.__class__.__name__ = class_name
        trainer.state = MagicMock(max_steps=1000)
        trainer.args = MagicMock(**args_attrs)
        return trainer

    def test_completion_callback_registered_for_dpo_trainer(self):
        """Test that completion callback is registered for DPO trainer."""
        plugin = self._make_plugin()
        cfg = self._make_cfg(
            swanlab_log_completions=True,
            swanlab_completion_log_interval=50,
            swanlab_completion_max_buffer=64,
        )
        mock_trainer = self._make_trainer(
            "AxolotlDPOTrainer",
            num_train_epochs=3,
            train_batch_size=4,
            gradient_accumulation_steps=2,
        )

        with patch("swanlab.config.update"):
            plugin.post_trainer_create(cfg, mock_trainer)

        # Exactly one callback, configured from the swanlab_completion_* keys.
        mock_trainer.add_callback.assert_called_once()
        callback = mock_trainer.add_callback.call_args[0][0]
        assert callback.__class__.__name__ == "SwanLabRLHFCompletionCallback"
        assert callback.log_interval == 50
        assert callback.logger.maxlen == 64

    def test_completion_callback_not_registered_for_non_rlhf_trainer(self):
        """Test that completion callback is NOT registered for non-RLHF trainers."""
        plugin = self._make_plugin()
        cfg = self._make_cfg(swanlab_log_completions=True)
        # Regular SFT trainer (not RLHF)
        mock_trainer = self._make_trainer("AxolotlTrainer")

        with patch("swanlab.config.update"):
            plugin.post_trainer_create(cfg, mock_trainer)

        mock_trainer.add_callback.assert_not_called()

    def test_completion_callback_not_registered_when_disabled(self):
        """Test that completion callback is not registered when disabled in config."""
        plugin = self._make_plugin()
        cfg = self._make_cfg(swanlab_log_completions=False)  # Disabled
        mock_trainer = self._make_trainer("AxolotlDPOTrainer")

        with patch("swanlab.config.update"):
            plugin.post_trainer_create(cfg, mock_trainer)

        mock_trainer.add_callback.assert_not_called()
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabProfiling:
"""Tests for SwanLab profiling utilities."""
def test_profiling_context_logs_duration(self):
    """The profiling context logs how long the wrapped block took."""
    from axolotl.integrations.swanlab.profiling import swanlab_profiling_context

    metric_key = "profiling/Time taken: TestTrainer.test_function"

    # Trainer mock with SwanLab switched on in its config.
    trainer = MagicMock()
    trainer.axolotl_cfg = MagicMock(use_swanlab=True)
    trainer.__class__.__name__ = "TestTrainer"

    with patch("swanlab.get_run") as get_run_spy, patch("swanlab.log") as log_spy:
        # A non-None run stands in for an initialized SwanLab session.
        get_run_spy.return_value = MagicMock()

        with swanlab_profiling_context(trainer, "test_function"):
            time.sleep(0.01)  # stand-in for real work

        log_spy.assert_called_once()
        payload = log_spy.call_args[0][0]
        assert metric_key in payload
        # The block slept for 10 ms, so the recorded duration is at least that.
        assert payload[metric_key] >= 0.01
def test_profiling_context_skips_when_swanlab_disabled(self):
    """Test that profiling is skipped when SwanLab is disabled."""
    from axolotl.integrations.swanlab.profiling import swanlab_profiling_context

    mock_trainer = MagicMock()
    mock_trainer.axolotl_cfg = MagicMock(use_swanlab=False)  # Disabled

    # Report an active run so that ``use_swanlab=False`` is the only reason
    # left to skip logging. Without this patch the test could pass merely
    # because no SwanLab run exists in the test process.
    with (
        patch("swanlab.get_run", return_value=MagicMock()),
        patch("swanlab.log") as mock_log,
    ):
        with swanlab_profiling_context(mock_trainer, "test_function"):
            time.sleep(0.01)
        # Should NOT log when disabled
        mock_log.assert_not_called()
def test_profiling_context_skips_when_swanlab_not_initialized(self):
    """Nothing is logged when swanlab.get_run() reports no active run."""
    from axolotl.integrations.swanlab.profiling import swanlab_profiling_context

    trainer = MagicMock()
    trainer.axolotl_cfg = MagicMock(use_swanlab=True)

    no_run = patch("swanlab.get_run", return_value=None)
    log_patch = patch("swanlab.log")
    with no_run, log_patch as log_spy:
        with swanlab_profiling_context(trainer, "test_function"):
            time.sleep(0.01)
        # Enabled in config, but without a run nothing may be logged.
        log_spy.assert_not_called()
def test_profiling_decorator(self):
    """Test swanlab_profile decorator.

    The decorated method must keep its return value, and its duration must be
    logged under ``<ClassName>.<method_name>``.
    """
    from axolotl.integrations.swanlab.profiling import swanlab_profile

    class MockTrainer:
        def __init__(self):
            # Use ``axolotl_cfg``, the attribute every other profiling test
            # configures on the trainer. The old ``cfg`` name is stale.
            self.axolotl_cfg = MagicMock(use_swanlab=True)

        @swanlab_profile
        def expensive_method(self, x):
            time.sleep(0.01)
            return x * 2

    trainer = MockTrainer()
    with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
        mock_get_run.return_value = MagicMock()
        result = trainer.expensive_method(5)

        # Verify method still works correctly
        assert result == 10

        # Verify profiling was logged
        mock_log.assert_called_once()
        logged_data = mock_log.call_args[0][0]
        assert "profiling/Time taken: MockTrainer.expensive_method" in logged_data
def test_profiling_config(self):
    """ProfilingConfig filters by minimum duration and by call interval."""
    from axolotl.integrations.swanlab.profiling import ProfilingConfig

    config = ProfilingConfig(enabled=True, min_duration_ms=1.0, log_interval=5)
    assert config.enabled is True

    # Duration threshold: 0.1 ms is below the 1.0 ms floor, 2.0 ms is above it.
    assert config.should_log("func1", 0.0001) is False
    assert config.should_log("func2", 0.002) is True

    # Interval (log_interval=5): for "func3" the 1st and 5th calls log,
    # calls 2-4 are skipped.
    expected = [True, False, False, False, True]
    for call_no, want in enumerate(expected, start=1):
        got = config.should_log("func3", 0.002)
        assert got is want, f"call {call_no}: expected {want}, got {got}"
def test_profiling_config_when_disabled(self):
    """A disabled ProfilingConfig never asks for a measurement to be logged."""
    from axolotl.integrations.swanlab.profiling import ProfilingConfig

    disabled = ProfilingConfig(enabled=False)
    # Even a very long (100 s) duration must be rejected while disabled.
    assert disabled.should_log("func1", 100.0) is False
def test_profiling_context_advanced(self):
    """Test advanced profiling context with custom config.

    Operations shorter than ``ProfilingConfig.min_duration_ms`` must not be
    logged; longer ones must be.
    """
    from axolotl.integrations.swanlab.profiling import (
        ProfilingConfig,
        swanlab_profiling_context_advanced,
    )

    mock_trainer = MagicMock()
    mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
    mock_trainer.__class__.__name__ = "TestTrainer"

    # Config that filters out operations faster than 10 ms.
    config = ProfilingConfig(min_duration_ms=10.0)

    with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
        mock_get_run.return_value = MagicMock()

        # Fast operation: do no work at all. A real sleep here (even 1 ms) can
        # overshoot 10 ms on a loaded CI runner or a coarse OS timer, which
        # made this assertion flaky.
        with swanlab_profiling_context_advanced(mock_trainer, "fast_op", config):
            pass
        mock_log.assert_not_called()

        # Slow operation: time.sleep waits at least the requested 15 ms,
        # which is safely above the 10 ms threshold.
        with swanlab_profiling_context_advanced(mock_trainer, "slow_op", config):
            time.sleep(0.015)
        mock_log.assert_called_once()
def test_profiling_with_exception(self):
    """Test that profiling still logs even when exception occurs.

    The exception raised inside the profiled block must also propagate out of
    the profiling context rather than being swallowed.
    """
    from axolotl.integrations.swanlab.profiling import swanlab_profiling_context

    mock_trainer = MagicMock()
    mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
    mock_trainer.__class__.__name__ = "TestTrainer"

    with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
        mock_get_run.return_value = MagicMock()

        # Unlike try/except-pass, pytest.raises fails the test if the
        # profiling context swallows the exception instead of re-raising it.
        with pytest.raises(ValueError, match="Test error"):
            with swanlab_profiling_context(mock_trainer, "error_function"):
                time.sleep(0.01)
                raise ValueError("Test error")

        # Should still log duration even with exception
        mock_log.assert_called_once()