# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for the SwanLab integration plugin.

Covers configuration validation (``SwanLabConfig``), conflict detection in
``SwanLabPlugin.register()``, multi-logger warnings, ``pre_model_load()``
initialization (including Lark notifications), RLHF completion logging, and
profiling helpers.
"""

import logging
import os
import sys
import time
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError
from transformers.utils.import_utils import _is_package_available

from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin

SWANLAB_INSTALLED = _is_package_available("swanlab")


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabConfigValidators:
    """Tests for Pydantic field validators in SwanLabConfig."""

    def test_valid_swanlab_mode_cloud(self):
        """Test that 'cloud' mode is valid."""
        config = SwanLabConfig(swanlab_mode="cloud")
        assert config.swanlab_mode == "cloud"

    def test_valid_swanlab_mode_local(self):
        """Test that 'local' mode is valid."""
        config = SwanLabConfig(swanlab_mode="local")
        assert config.swanlab_mode == "local"

    def test_valid_swanlab_mode_offline(self):
        """Test that 'offline' mode is valid."""
        config = SwanLabConfig(swanlab_mode="offline")
        assert config.swanlab_mode == "offline"

    def test_valid_swanlab_mode_disabled(self):
        """Test that 'disabled' mode is valid."""
        config = SwanLabConfig(swanlab_mode="disabled")
        assert config.swanlab_mode == "disabled"

    def test_invalid_swanlab_mode(self):
        """Test that an invalid mode raises a pydantic ValidationError."""
        with pytest.raises(ValidationError) as exc_info:
            SwanLabConfig(swanlab_mode="invalid")

        # The message should name the problem and list every accepted mode.
        error_msg = str(exc_info.value)
        assert "Invalid swanlab_mode" in error_msg
        assert "cloud" in error_msg
        assert "local" in error_msg
        assert "offline" in error_msg
        assert "disabled" in error_msg

    def test_swanlab_mode_none_allowed(self):
        """Test that None mode is allowed (will use default)."""
        config = SwanLabConfig(swanlab_mode=None)
        assert config.swanlab_mode is None

    def test_valid_swanlab_project(self):
        """Test that valid project name is accepted."""
        config = SwanLabConfig(swanlab_project="my-project")
        assert config.swanlab_project == "my-project"

    def test_swanlab_project_none_allowed(self):
        """Test that None project is allowed."""
        config = SwanLabConfig(swanlab_project=None)
        assert config.swanlab_project is None

    def test_empty_swanlab_project_rejected(self):
        """Test that empty string project name is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            SwanLabConfig(swanlab_project="")

        error_msg = str(exc_info.value)
        assert "cannot be an empty string" in error_msg

    def test_whitespace_only_project_rejected(self):
        """Test that whitespace-only project name is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            SwanLabConfig(swanlab_project=" ")

        error_msg = str(exc_info.value)
        assert "cannot be an empty string" in error_msg

    def test_use_swanlab_true_requires_project(self):
        """Test that use_swanlab=True requires swanlab_project."""
        with pytest.raises(ValidationError) as exc_info:
            SwanLabConfig(use_swanlab=True, swanlab_project=None)

        error_msg = str(exc_info.value)
        assert "swanlab_project" in error_msg.lower()
        assert "not set" in error_msg.lower()

    def test_use_swanlab_true_with_project_valid(self):
        """Test that use_swanlab=True with project is valid."""
        config = SwanLabConfig(use_swanlab=True, swanlab_project="my-project")
        assert config.use_swanlab is True
        assert config.swanlab_project == "my-project"

    def test_use_swanlab_false_no_project_valid(self):
        """Test that use_swanlab=False without project is valid."""
        config = SwanLabConfig(use_swanlab=False, swanlab_project=None)
        assert config.use_swanlab is False
        assert config.swanlab_project is None

    def test_use_swanlab_none_no_project_valid(self):
        """Test that use_swanlab=None without project is valid."""
        config = SwanLabConfig(use_swanlab=None, swanlab_project=None)
        assert config.use_swanlab is None
        assert config.swanlab_project is None


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginRegister:
    """Tests for SwanLabPlugin.register() conflict detection."""

    def test_register_without_use_swanlab(self):
        """Test that register works when SwanLab is not enabled."""
        plugin = SwanLabPlugin()
        cfg = {"use_swanlab": False}

        # Should not raise
        plugin.register(cfg)

    def test_register_use_swanlab_missing_project(self):
        """Test that use_swanlab=True without project raises ValueError."""
        plugin = SwanLabPlugin()
        cfg = {"use_swanlab": True}

        with pytest.raises(ValueError) as exc_info:
            plugin.register(cfg)

        error_msg = str(exc_info.value)
        assert "swanlab_project" in error_msg
        assert "not set" in error_msg
        assert "Solutions" in error_msg

    def test_register_use_swanlab_with_project_valid(self):
        """Test that use_swanlab=True with project is valid."""
        plugin = SwanLabPlugin()
        cfg = {"use_swanlab": True, "swanlab_project": "my-project"}

        # Should not raise
        plugin.register(cfg)

    def test_register_invalid_mode(self):
        """Test that invalid swanlab_mode raises ValueError."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "swanlab_mode": "invalid-mode",
        }

        with pytest.raises(ValueError) as exc_info:
            plugin.register(cfg)

        error_msg = str(exc_info.value)
        assert "Invalid swanlab_mode" in error_msg
        assert "cloud" in error_msg
        assert "local" in error_msg

    def test_register_valid_modes(self):
        """Test that all valid modes are accepted."""
        plugin = SwanLabPlugin()
        valid_modes = ["cloud", "local", "offline", "disabled"]

        for mode in valid_modes:
            # Fresh dict per iteration: register() may mutate its argument.
            cfg = {
                "use_swanlab": True,
                "swanlab_project": "my-project",
                "swanlab_mode": mode,
            }
            # Should not raise
            plugin.register(cfg)

    def test_register_auto_enable_swanlab(self):
        """Test that providing swanlab_project auto-enables use_swanlab."""
        plugin = SwanLabPlugin()
        cfg = {"swanlab_project": "my-project"}

        plugin.register(cfg)

        assert cfg["use_swanlab"] is True

    def test_register_cloud_mode_without_api_key_warns(self, caplog):
        """Test that cloud mode without API key logs warning."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "swanlab_mode": "cloud",
        }

        # Empty the environment (restored on exit) so no API key is picked up.
        with patch.dict(os.environ, {}, clear=True):
            with caplog.at_level(logging.WARNING):
                plugin.register(cfg)

        # Should log warning about missing API key
        warning_messages = [record.message for record in caplog.records]
        assert any("API key" in msg for msg in warning_messages)


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestMultiLoggerDetection:
    """Tests for multi-logger conflict detection."""

    def test_single_logger_no_warning(self, caplog):
        """Test that single logger doesn't trigger warning."""
        plugin = SwanLabPlugin()
        cfg = {"use_swanlab": True, "swanlab_project": "my-project"}

        with caplog.at_level(logging.WARNING):
            plugin.register(cfg)

        # Should not log multi-logger warning
        warning_messages = [record.message for record in caplog.records]
        assert not any("Multiple logging tools" in msg for msg in warning_messages)

    def test_two_loggers_warning(self, caplog):
        """Test that two loggers trigger warning."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "use_wandb": True,
        }

        with caplog.at_level(logging.WARNING):
            plugin.register(cfg)

        # Should log multi-logger warning naming both tools
        warning_messages = [record.message for record in caplog.records]
        assert any("Multiple logging tools" in msg for msg in warning_messages)
        assert any("SwanLab" in msg and "WandB" in msg for msg in warning_messages)

    def test_three_loggers_error(self, caplog):
        """Test that three loggers trigger error-level warning."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "use_wandb": True,
            "use_mlflow": True,
        }

        with caplog.at_level(logging.ERROR):
            plugin.register(cfg)

        # Should log error-level warning
        error_messages = [
            record.message
            for record in caplog.records
            if record.levelno >= logging.ERROR
        ]
        assert any("logging tools enabled" in msg for msg in error_messages)

    def test_multi_logger_with_comet(self, caplog):
        """Test that Comet is detected in multi-logger scenario."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "comet_api_key": "test-key",
        }

        with caplog.at_level(logging.WARNING):
            plugin.register(cfg)

        # Should detect Comet
        warning_messages = [record.message for record in caplog.records]
        assert any("Comet" in msg for msg in warning_messages)

    def test_multi_logger_with_comet_project(self, caplog):
        """Test that Comet is detected via comet_project_name."""
        plugin = SwanLabPlugin()
        cfg = {
            "use_swanlab": True,
            "swanlab_project": "my-project",
            "comet_project_name": "test-project",
        }

        with caplog.at_level(logging.WARNING):
            plugin.register(cfg)

        # Should detect Comet
        warning_messages = [record.message for record in caplog.records]
        assert any("Comet" in msg for msg in warning_messages)


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginPreModelLoad:
    """Tests for SwanLabPlugin.pre_model_load() runtime checks."""

    def test_pre_model_load_disabled(self):
        """Test that pre_model_load does nothing when SwanLab is disabled."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = False

        # Should not raise
        plugin.pre_model_load(cfg)

    def test_pre_model_load_import_error(self):
        """Test that missing swanlab package raises clear ImportError."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True

        # A ``None`` entry in sys.modules makes ``import swanlab`` raise
        # ImportError while every other import keeps working. Patching
        # ``builtins.__import__`` instead would break *all* imports made
        # during the call, not just the one under test.
        with patch.dict(sys.modules, {"swanlab": None}):
            with pytest.raises(ImportError) as exc_info:
                plugin.pre_model_load(cfg)

        error_msg = str(exc_info.value)
        assert "SwanLab is not installed" in error_msg
        assert "pip install swanlab" in error_msg

    @patch("axolotl.utils.distributed.is_main_process")
    @patch("axolotl.utils.distributed.get_world_size")
    def test_pre_model_load_non_main_process_skips(
        self, mock_get_world_size, mock_is_main_process
    ):
        """Test that non-main process skips SwanLab initialization."""
        mock_get_world_size.return_value = 2
        mock_is_main_process.return_value = False

        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True

        with patch("swanlab.init") as mock_init:
            plugin.pre_model_load(cfg)

            # Should NOT call swanlab.init
            mock_init.assert_not_called()

    @patch("axolotl.utils.distributed.is_main_process")
    @patch("axolotl.utils.distributed.get_world_size")
    def test_pre_model_load_distributed_logging(
        self, mock_get_world_size, mock_is_main_process, caplog
    ):
        """Test that distributed training logs world size info."""
        mock_get_world_size.return_value = 4
        mock_is_main_process.return_value = True

        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "cloud"

        with patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"):
            with caplog.at_level(logging.INFO):
                plugin.pre_model_load(cfg)

        # Should log distributed training info
        info_messages = [record.message for record in caplog.records]
        assert any("world_size=4" in msg for msg in info_messages)
        assert any("Only rank 0" in msg for msg in info_messages)


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabInitKwargs:
    """Tests for SwanLab initialization with direct parameter passing."""

    def test_custom_branding_added_to_config(self):
        """Test that Axolotl custom branding is added to SwanLab config."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "test-project",
            }
        )

        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)

        # Verify custom branding is present
        assert "config" in init_kwargs
        assert init_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl"

    def test_api_key_passed_directly(self):
        """Test that API key is passed directly to swanlab.init() instead of via env var."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "test-project",
                "swanlab_api_key": "test-api-key-12345",
            }
        )

        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)

        # Verify API key is in init_kwargs (not set as env var)
        assert "api_key" in init_kwargs
        assert init_kwargs["api_key"] == "test-api-key-12345"

    def test_private_deployment_hosts_passed_directly(self):
        """Test that private deployment hosts are passed directly to swanlab.init()."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "internal-project",
                "swanlab_web_host": "https://swanlab.company.com",
                "swanlab_api_host": "https://api-swanlab.company.com",
            }
        )

        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)

        # Verify private deployment hosts are in init_kwargs
        assert "web_host" in init_kwargs
        assert init_kwargs["web_host"] == "https://swanlab.company.com"
        assert "api_host" in init_kwargs
        assert init_kwargs["api_host"] == "https://api-swanlab.company.com"

    @patch("axolotl.utils.distributed.is_main_process")
    def test_full_private_deployment_init(self, mock_is_main_process):
        """Test complete initialization with private deployment configuration."""
        mock_is_main_process.return_value = True

        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "secure-project",
                "swanlab_experiment_name": "experiment-001",
                "swanlab_mode": "cloud",
                "swanlab_api_key": "private-key-xyz",
                "swanlab_web_host": "https://swanlab.internal.net",
                "swanlab_api_host": "https://api.swanlab.internal.net",
                "swanlab_workspace": "research-team",
            }
        )

        with patch("swanlab.init") as mock_init:
            plugin.pre_model_load(cfg)

            # Verify swanlab.init was called with all parameters
            mock_init.assert_called_once()
            call_kwargs = mock_init.call_args[1]

            assert call_kwargs["project"] == "secure-project"
            assert call_kwargs["experiment_name"] == "experiment-001"
            assert call_kwargs["mode"] == "cloud"
            assert call_kwargs["api_key"] == "private-key-xyz"
            assert call_kwargs["web_host"] == "https://swanlab.internal.net"
            assert call_kwargs["api_host"] == "https://api.swanlab.internal.net"
            assert call_kwargs["workspace"] == "research-team"
            assert call_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl"

    def test_env_vars_not_set_for_api_params(self):
        """Test that environment variables are NOT set for API parameters.

        The plugin passes these values directly to ``swanlab.init()`` (the old
        ``_setup_swanlab_env()`` path was removed), so ``pre_model_load`` must
        not export any of them via ``os.environ``.
        """
        from axolotl.utils.dict import DictDefault

        env_keys = [
            "SWANLAB_API_KEY",
            "SWANLAB_WEB_HOST",
            "SWANLAB_API_HOST",
            "SWANLAB_MODE",
        ]

        plugin = SwanLabPlugin()
        cfg = DictDefault(
            {
                "use_swanlab": True,
                "swanlab_project": "test-project",
                "swanlab_api_key": "test-key",
                "swanlab_web_host": "https://test.com",
                "swanlab_api_host": "https://api-test.com",
                "swanlab_mode": "cloud",
            }
        )

        # patch.dict snapshots os.environ and restores it on exit, so removing
        # these keys here does not leak into other tests.
        with (
            patch.dict(os.environ),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("swanlab.init"),
        ):
            for key in env_keys:
                os.environ.pop(key, None)

            plugin.pre_model_load(cfg)

            # swanlab.init is mocked, so any of these keys present now must
            # have been exported by the plugin itself.
            leaked = [key for key in env_keys if key in os.environ]
            assert not leaked, f"pre_model_load exported env vars: {leaked}"


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestLarkNotificationIntegration:
    """Tests for Lark (Feishu) notification integration."""

    def test_lark_callback_registration_with_webhook_only(self):
        """Test Lark callback registration with webhook URL only (no secret)."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "local"
        cfg.swanlab_lark_webhook_url = (
            "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook"
        )
        cfg.swanlab_lark_secret = None

        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            # Mock LarkCallback import
            with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback:
                mock_lark_instance = MagicMock()
                MockLarkCallback.return_value = mock_lark_instance

                plugin.pre_model_load(cfg)

                # Verify LarkCallback was instantiated with correct params
                MockLarkCallback.assert_called_once_with(
                    webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook",
                    secret=None,
                )

                # Verify callback was registered
                mock_register.assert_called_once_with([mock_lark_instance])

    def test_lark_callback_registration_with_secret(self):
        """Test Lark callback registration with webhook URL and HMAC secret."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "local"
        cfg.swanlab_lark_webhook_url = (
            "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook"
        )
        cfg.swanlab_lark_secret = "test-hmac-secret"

        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback:
                mock_lark_instance = MagicMock()
                MockLarkCallback.return_value = mock_lark_instance

                plugin.pre_model_load(cfg)

                # Verify LarkCallback was instantiated with secret
                MockLarkCallback.assert_called_once_with(
                    webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook",
                    secret="test-hmac-secret",
                )

                mock_register.assert_called_once_with([mock_lark_instance])

    def test_lark_callback_not_registered_without_webhook(self):
        """Test that Lark callback is NOT registered when webhook URL not provided."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "local"
        cfg.swanlab_lark_webhook_url = None  # No webhook
        cfg.swanlab_lark_secret = None

        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            plugin.pre_model_load(cfg)

            # Verify register_callbacks was NOT called
            mock_register.assert_not_called()

    def test_lark_import_error_handled_gracefully(self, caplog):
        """Test that ImportError for Lark plugin is handled gracefully."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "local"
        cfg.swanlab_lark_webhook_url = (
            "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook"
        )
        cfg.swanlab_lark_secret = None

        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            # Mock ImportError for LarkCallback
            with patch(
                "swanlab.plugin.notification.LarkCallback",
                side_effect=ImportError(
                    "No module named 'swanlab.plugin.notification'"
                ),
            ):
                with caplog.at_level(logging.WARNING):
                    plugin.pre_model_load(cfg)

                # Should log warning about missing Lark plugin
                warning_messages = [record.message for record in caplog.records]
                assert any(
                    "Failed to import SwanLab Lark plugin" in msg
                    for msg in warning_messages
                )
                assert any("SwanLab >= 0.3.0" in msg for msg in warning_messages)

    def test_lark_warning_for_missing_secret(self, caplog):
        """Test that warning is logged when Lark webhook has no HMAC secret."""
        plugin = SwanLabPlugin()
        cfg = MagicMock()
        cfg.use_swanlab = True
        cfg.swanlab_project = "test-project"
        cfg.swanlab_mode = "local"
        cfg.swanlab_lark_webhook_url = (
            "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook"
        )
        cfg.swanlab_lark_secret = None  # No secret

        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks"),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            with patch("swanlab.plugin.notification.LarkCallback"):
                with caplog.at_level(logging.WARNING):
                    plugin.pre_model_load(cfg)

                # Should log warning about missing secret
                warning_messages = [record.message for record in caplog.records]
                assert any(
                    "no secret configured" in msg.lower() for msg in warning_messages
                )
                assert any("swanlab_lark_secret" in msg for msg in warning_messages)


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginIntegration:
    """Integration tests for SwanLab plugin lifecycle."""

    def test_full_lifecycle_valid_config(self):
        """Test full plugin lifecycle with valid configuration."""
        plugin = SwanLabPlugin()

        # Register
        cfg_dict = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_mode": "local",
        }
        plugin.register(cfg_dict)

        # Pre-model load (mock SwanLab)
        cfg_obj = MagicMock()
        cfg_obj.use_swanlab = True
        cfg_obj.swanlab_project = "test-project"
        cfg_obj.swanlab_mode = "local"
        cfg_obj.swanlab_lark_webhook_url = None  # No Lark

        with (
            patch("swanlab.init") as mock_init,
            patch("swanlab.__version__", "0.3.0"),
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            plugin.pre_model_load(cfg_obj)

            # Should call swanlab.init
            mock_init.assert_called_once()

    def test_lifecycle_with_multi_logger_warning(self, caplog):
        """Test lifecycle with multi-logger warning."""
        plugin = SwanLabPlugin()

        cfg_dict = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "use_wandb": True,
        }

        with caplog.at_level(logging.WARNING):
            plugin.register(cfg_dict)

        # Should have multi-logger warning
        warning_messages = [record.message for record in caplog.records]
        assert any("Multiple logging tools" in msg for msg in warning_messages)

    def test_lifecycle_invalid_config_fails_early(self):
        """Test that invalid config fails at register stage."""
        plugin = SwanLabPlugin()

        cfg_dict = {
            "use_swanlab": True,
            # Missing swanlab_project
        }

        # Should fail at register, not pre_model_load
        with pytest.raises(ValueError):
            plugin.register(cfg_dict)

    def test_full_lifecycle_with_lark_notifications(self):
        """Test full lifecycle including Lark notification registration."""
        plugin = SwanLabPlugin()

        # Register
        cfg_dict = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_mode": "cloud",
        }
        plugin.register(cfg_dict)

        # Pre-model load with Lark config
        cfg_obj = MagicMock()
        cfg_obj.use_swanlab = True
        cfg_obj.swanlab_project = "test-project"
        cfg_obj.swanlab_mode = "cloud"
        cfg_obj.swanlab_lark_webhook_url = (
            "https://open.feishu.cn/open-apis/bot/v2/hook/test"
        )
        cfg_obj.swanlab_lark_secret = "secret123"

        with (
            patch("swanlab.init"),
            patch("swanlab.__version__", "0.3.0"),
            patch("swanlab.register_callbacks") as mock_register,
            patch("axolotl.utils.distributed.is_main_process", return_value=True),
            patch("axolotl.utils.distributed.get_world_size", return_value=1),
        ):
            with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback:
                mock_lark_instance = MagicMock()
                MockLarkCallback.return_value = mock_lark_instance

                plugin.pre_model_load(cfg_obj)

                # Verify both SwanLab init AND Lark callback registration
                MockLarkCallback.assert_called_once()
                mock_register.assert_called_once_with([mock_lark_instance])


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestCompletionLogger:
    """Tests for CompletionLogger utility class."""

    def test_completion_logger_initialization(self):
        """Test CompletionLogger initializes with correct maxlen."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=64)
        assert logger.maxlen == 64
        assert len(logger) == 0

    def test_add_dpo_completion(self):
        """Test adding DPO completions to buffer."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_dpo_completion(
            step=0,
            prompt="What is AI?",
            chosen="Artificial Intelligence is...",
            rejected="AI means...",
            reward_diff=0.5,
        )

        assert len(logger) == 1
        entry = logger.data[0]
        assert entry["step"] == 0
        assert entry["prompt"] == "What is AI?"
        assert entry["chosen"] == "Artificial Intelligence is..."
        assert entry["rejected"] == "AI means..."
        assert entry["reward_diff"] == 0.5

    def test_add_kto_completion(self):
        """Test adding KTO completions to buffer."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_kto_completion(
            step=1,
            prompt="Explain quantum physics",
            completion="Quantum physics is...",
            label=True,
            reward=0.8,
        )

        assert len(logger) == 1
        entry = logger.data[0]
        assert entry["step"] == 1
        assert entry["prompt"] == "Explain quantum physics"
        assert entry["completion"] == "Quantum physics is..."
        # A True KTO label is stored as the string "desirable".
        assert entry["label"] == "desirable"
        assert entry["reward"] == 0.8

    def test_add_orpo_completion(self):
        """Test adding ORPO completions to buffer."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_orpo_completion(
            step=2,
            prompt="Write a poem",
            chosen="Roses are red...",
            rejected="Violets are blue...",
            log_odds_ratio=1.2,
        )

        assert len(logger) == 1
        entry = logger.data[0]
        assert entry["step"] == 2
        assert entry["chosen"] == "Roses are red..."
        assert entry["rejected"] == "Violets are blue..."
        assert entry["log_odds_ratio"] == 1.2

    def test_add_grpo_completion(self):
        """Test adding GRPO completions to buffer."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_grpo_completion(
            step=3,
            prompt="Solve this problem",
            completion="The answer is 42",
            reward=0.9,
            advantage=0.3,
        )

        assert len(logger) == 1
        entry = logger.data[0]
        assert entry["step"] == 3
        assert entry["completion"] == "The answer is 42"
        assert entry["reward"] == 0.9
        assert entry["advantage"] == 0.3

    def test_memory_bounded_buffer(self):
        """Test that buffer respects maxlen and drops oldest entries."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=3)

        # Add 5 completions
        for i in range(5):
            logger.add_dpo_completion(
                step=i,
                prompt=f"Prompt {i}",
                chosen=f"Chosen {i}",
                rejected=f"Rejected {i}",
            )

        # Should only keep last 3
        assert len(logger) == 3
        assert logger.data[0]["step"] == 2  # Oldest kept
        assert logger.data[1]["step"] == 3
        assert logger.data[2]["step"] == 4  # Newest

    def test_log_to_swanlab_when_not_initialized(self):
        """Test logging gracefully fails when SwanLab not initialized."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_dpo_completion(
            step=0,
            prompt="Test",
            chosen="A",
            rejected="B",
        )

        # get_run() returning None simulates "no active SwanLab run".
        with patch("swanlab.get_run", return_value=None):
            result = logger.log_to_swanlab()
            assert result is False  # Should fail gracefully

    def test_log_to_swanlab_success(self):
        """Test successful logging to SwanLab."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_dpo_completion(
            step=0,
            prompt="Test prompt",
            chosen="Chosen response",
            rejected="Rejected response",
            reward_diff=0.5,
        )

        with (
            patch("swanlab.get_run") as mock_get_run,
            patch("swanlab.log") as mock_log,
            patch("swanlab.echarts.Table") as MockTable,
        ):
            mock_get_run.return_value = MagicMock()  # SwanLab initialized
            mock_table_instance = MagicMock()
            MockTable.return_value = mock_table_instance

            result = logger.log_to_swanlab(table_name="test_table")

            assert result is True
            mock_log.assert_called_once()
            mock_table_instance.add.assert_called_once()

    def test_clear_buffer(self):
        """Test clearing the completion buffer."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=10)
        logger.add_dpo_completion(
            step=0,
            prompt="Test",
            chosen="A",
            rejected="B",
        )

        assert len(logger) == 1
        logger.clear()
        assert len(logger) == 0

    def test_repr(self):
        """Test string representation."""
        from axolotl.integrations.swanlab.completion_logger import CompletionLogger

        logger = CompletionLogger(maxlen=128)
        logger.add_dpo_completion(
            step=0,
            prompt="Test",
            chosen="A",
            rejected="B",
        )

        repr_str = repr(logger)
        assert "CompletionLogger" in repr_str
        assert "maxlen=128" in repr_str
        assert "buffered=1/128" in repr_str


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabRLHFCompletionCallback:
    """Tests for SwanLabRLHFCompletionCallback."""

    def test_callback_initialization(self):
        """Test callback initializes with correct parameters."""
        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback

        callback = SwanLabRLHFCompletionCallback(
            log_interval=50,
            max_completions=64,
            table_name="custom_table",
        )

        assert callback.log_interval == 50
        assert callback.logger.maxlen == 64
        assert callback.table_name == "custom_table"
        assert callback.trainer_type is None

    def test_trainer_type_detection_dpo(self):
        """Test DPO trainer type is detected correctly."""
        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback

        callback = SwanLabRLHFCompletionCallback()

        # Mock trainer with DPO in name
        mock_trainer = MagicMock()
        mock_trainer.__class__.__name__ = "AxolotlDPOTrainer"

        callback.on_init_end(
            args=MagicMock(),
            state=MagicMock(),
            control=MagicMock(),
            trainer=mock_trainer,
        )

        assert callback.trainer_type == "dpo"

    def test_trainer_type_detection_kto(self):
        """Test KTO trainer type is detected correctly."""
        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback

        callback = SwanLabRLHFCompletionCallback()

        mock_trainer = MagicMock()
        mock_trainer.__class__.__name__ = "AxolotlKTOTrainer"

        callback.on_init_end(
            args=MagicMock(),
            state=MagicMock(),
            control=MagicMock(),
            trainer=mock_trainer,
        )

        assert callback.trainer_type == "kto"

    def test_on_train_end_logs_completions(self):
        """Test that completions are logged at end of training."""
        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback

        callback = SwanLabRLHFCompletionCallback()
        callback.trainer_type = "dpo"

        # Add some completions to buffer
        callback.logger.add_dpo_completion(
            step=0,
            prompt="Test",
            chosen="A",
            rejected="B",
        )

        with patch.object(callback.logger, "log_to_swanlab") as mock_log:
            callback.on_train_end(
                args=MagicMock(),
                state=MagicMock(global_step=100),
                control=MagicMock(),
            )

            # Should log remaining completions
            mock_log.assert_called_once()


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabPluginCompletionIntegration:
    """Integration tests for completion logging in SwanLabPlugin."""

    def test_completion_callback_registered_for_dpo_trainer(self):
        """Test that completion callback is registered for DPO trainer."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        plugin.swanlab_initialized = True  # Simulate SwanLab initialized

        cfg = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_log_completions": True,
            "swanlab_completion_log_interval": 50,
            "swanlab_completion_max_buffer": 64,
        }
        cfg_obj = DictDefault(cfg)

        # Mock DPO trainer
        mock_trainer = MagicMock()
        mock_trainer.__class__.__name__ = "AxolotlDPOTrainer"
        mock_trainer.state = MagicMock(max_steps=1000)
        mock_trainer.args = MagicMock(
            num_train_epochs=3,
            train_batch_size=4,
            gradient_accumulation_steps=2,
        )

        with patch("swanlab.config.update"):
            plugin.post_trainer_create(cfg_obj, mock_trainer)

        # Verify callback was added with the configured interval/buffer size
        mock_trainer.add_callback.assert_called_once()
        callback = mock_trainer.add_callback.call_args[0][0]
        assert callback.__class__.__name__ == "SwanLabRLHFCompletionCallback"
        assert callback.log_interval == 50
        assert callback.logger.maxlen == 64

    def test_completion_callback_not_registered_for_non_rlhf_trainer(self):
        """Test that completion callback is NOT registered for non-RLHF trainers."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        plugin.swanlab_initialized = True

        cfg = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_log_completions": True,
        }
        cfg_obj = DictDefault(cfg)

        # Mock regular SFT trainer (not RLHF)
        mock_trainer = MagicMock()
        mock_trainer.__class__.__name__ = "AxolotlTrainer"  # Not RLHF
        mock_trainer.state = MagicMock(max_steps=1000)
        mock_trainer.args = MagicMock()

        with patch("swanlab.config.update"):
            plugin.post_trainer_create(cfg_obj, mock_trainer)

        # Callback should NOT be added for non-RLHF trainer
        mock_trainer.add_callback.assert_not_called()

    def test_completion_callback_not_registered_when_disabled(self):
        """Test that completion callback is not registered when disabled in config."""
        from axolotl.utils.dict import DictDefault

        plugin = SwanLabPlugin()
        plugin.swanlab_initialized = True

        cfg = {
            "use_swanlab": True,
            "swanlab_project": "test-project",
            "swanlab_log_completions": False,  # Disabled
        }
        cfg_obj = DictDefault(cfg)

        # Mock DPO trainer
        mock_trainer = MagicMock()
        mock_trainer.__class__.__name__ = "AxolotlDPOTrainer"
        mock_trainer.state = MagicMock(max_steps=1000)
        mock_trainer.args = MagicMock()

        with patch("swanlab.config.update"):
            plugin.post_trainer_create(cfg_obj, mock_trainer)

        # Callback should NOT be added when disabled
        mock_trainer.add_callback.assert_not_called()


@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
class TestSwanLabProfiling:
    """Tests for SwanLab profiling utilities."""

    def test_profiling_context_logs_duration(self):
        """Test that profiling context logs execution duration."""
        from axolotl.integrations.swanlab.profiling import swanlab_profiling_context

        # Mock trainer with SwanLab enabled
        mock_trainer = MagicMock()
        mock_trainer.cfg = MagicMock(use_swanlab=True)
        mock_trainer.__class__.__name__ = "TestTrainer"

        with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
            mock_get_run.return_value = MagicMock()  # SwanLab initialized

            with swanlab_profiling_context(mock_trainer, "test_function"):
                time.sleep(0.01)  # Simulate work

            # Verify log was called with correct metric name
            mock_log.assert_called_once()
            logged_data = mock_log.call_args[0][0]
            assert "profiling/Time taken: TestTrainer.test_function" in logged_data
            # Duration should be > 0.01 seconds
            assert (
                logged_data["profiling/Time taken: TestTrainer.test_function"] >= 0.01
            )

    def test_profiling_context_skips_when_swanlab_disabled(self):
        """Test that profiling is skipped when SwanLab is disabled."""
        from axolotl.integrations.swanlab.profiling import swanlab_profiling_context

        mock_trainer = MagicMock()
        mock_trainer.cfg = MagicMock(use_swanlab=False)  #
Disabled with patch("swanlab.log") as mock_log: with swanlab_profiling_context(mock_trainer, "test_function"): time.sleep(0.01) # Should NOT log when disabled mock_log.assert_not_called() def test_profiling_context_skips_when_swanlab_not_initialized(self): """Test that profiling is skipped when SwanLab not initialized.""" from axolotl.integrations.swanlab.profiling import swanlab_profiling_context mock_trainer = MagicMock() mock_trainer.cfg = MagicMock(use_swanlab=True) with ( patch("swanlab.get_run", return_value=None), patch("swanlab.log") as mock_log, ): with swanlab_profiling_context(mock_trainer, "test_function"): time.sleep(0.01) # Should NOT log when not initialized mock_log.assert_not_called() def test_profiling_decorator(self): """Test swanlab_profile decorator.""" from axolotl.integrations.swanlab.profiling import swanlab_profile class MockTrainer: def __init__(self): self.cfg = MagicMock(use_swanlab=True) @swanlab_profile def expensive_method(self, x): time.sleep(0.01) return x * 2 trainer = MockTrainer() with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() result = trainer.expensive_method(5) # Verify method still works correctly assert result == 10 # Verify profiling was logged mock_log.assert_called_once() logged_data = mock_log.call_args[0][0] assert "profiling/Time taken: MockTrainer.expensive_method" in logged_data def test_profiling_config(self): """Test ProfilingConfig class.""" from axolotl.integrations.swanlab.profiling import ProfilingConfig config = ProfilingConfig( enabled=True, min_duration_ms=1.0, log_interval=5, ) # Test enabled check assert config.enabled is True # Test minimum duration filtering assert config.should_log("func1", 0.0001) is False # 0.1ms < 1.0ms threshold assert config.should_log("func2", 0.002) is True # 2.0ms > 1.0ms threshold # Test log interval assert config.should_log("func3", 0.002) is True # 1st call assert config.should_log("func3", 0.002) is 
False # 2nd call assert config.should_log("func3", 0.002) is False # 3rd call assert config.should_log("func3", 0.002) is False # 4th call assert config.should_log("func3", 0.002) is True # 5th call (interval=5) def test_profiling_config_when_disabled(self): """Test ProfilingConfig when disabled.""" from axolotl.integrations.swanlab.profiling import ProfilingConfig config = ProfilingConfig(enabled=False) # Should never log when disabled assert config.should_log("func1", 100.0) is False def test_profiling_context_advanced(self): """Test advanced profiling context with custom config.""" from axolotl.integrations.swanlab.profiling import ( ProfilingConfig, swanlab_profiling_context_advanced, ) mock_trainer = MagicMock() mock_trainer.cfg = MagicMock(use_swanlab=True) mock_trainer.__class__.__name__ = "TestTrainer" # Config that filters out very fast operations config = ProfilingConfig(min_duration_ms=10.0) # 10ms minimum with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() # Fast operation (< 10ms) - should NOT log with swanlab_profiling_context_advanced(mock_trainer, "fast_op", config): time.sleep(0.001) # 1ms mock_log.assert_not_called() # Slow operation (> 10ms) - should log with swanlab_profiling_context_advanced(mock_trainer, "slow_op", config): time.sleep(0.015) # 15ms mock_log.assert_called_once() def test_profiling_with_exception(self): """Test that profiling still logs even when exception occurs.""" from axolotl.integrations.swanlab.profiling import swanlab_profiling_context mock_trainer = MagicMock() mock_trainer.cfg = MagicMock(use_swanlab=True) mock_trainer.__class__.__name__ = "TestTrainer" with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() try: with swanlab_profiling_context(mock_trainer, "error_function"): time.sleep(0.01) raise ValueError("Test error") except ValueError: pass # Expected # Should still log duration 
even with exception mock_log.assert_called_once()