# Copyright 2024 Axolotl AI. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Unit tests for SwanLab Integration Plugin. Tests conflict detection, configuration validation, and multi-logger warnings. """ import importlib.util import logging import os import time from unittest.mock import MagicMock, patch import pytest from pydantic import ValidationError from axolotl.integrations.swanlab.args import SwanLabConfig from axolotl.integrations.swanlab.plugins import SwanLabPlugin SWANLAB_INSTALLED = importlib.util.find_spec("swanlab") is not None @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabConfigValidators: """Tests for Pydantic field validators in SwanLabConfig.""" def test_valid_swanlab_mode_cloud(self): """Test that 'cloud' mode is valid.""" config = SwanLabConfig(swanlab_mode="cloud") assert config.swanlab_mode == "cloud" def test_valid_swanlab_mode_local(self): """Test that 'local' mode is valid.""" config = SwanLabConfig(swanlab_mode="local") assert config.swanlab_mode == "local" def test_valid_swanlab_mode_offline(self): """Test that 'offline' mode is valid.""" config = SwanLabConfig(swanlab_mode="offline") assert config.swanlab_mode == "offline" def test_valid_swanlab_mode_disabled(self): """Test that 'disabled' mode is valid.""" config = SwanLabConfig(swanlab_mode="disabled") assert config.swanlab_mode == "disabled" def test_invalid_swanlab_mode(self): """Test that invalid mode raises ValueError.""" with pytest.raises(ValidationError) as exc_info: SwanLabConfig(swanlab_mode="invalid") error_msg = str(exc_info.value) assert "Invalid swanlab_mode" in error_msg assert "cloud" in error_msg assert "local" in error_msg assert "offline" in error_msg assert "disabled" in error_msg def test_swanlab_mode_none_allowed(self): """Test that None mode is allowed (will use default).""" config = SwanLabConfig(swanlab_mode=None) assert config.swanlab_mode is None def test_valid_swanlab_project(self): """Test that valid project name is accepted.""" config = SwanLabConfig(swanlab_project="my-project") assert config.swanlab_project == "my-project" def test_swanlab_project_none_allowed(self): """Test that None project is allowed.""" config = SwanLabConfig(swanlab_project=None) assert config.swanlab_project is None def test_empty_swanlab_project_rejected(self): """Test that empty string project name is rejected.""" with pytest.raises(ValidationError) as exc_info: SwanLabConfig(swanlab_project="") error_msg = str(exc_info.value) assert "cannot be an empty string" in error_msg def test_whitespace_only_project_rejected(self): """Test that whitespace-only project name is rejected.""" with pytest.raises(ValidationError) as exc_info: SwanLabConfig(swanlab_project=" ") error_msg = str(exc_info.value) assert "cannot be an empty string" in error_msg def test_use_swanlab_true_requires_project(self): """Test that use_swanlab=True requires swanlab_project.""" with pytest.raises(ValidationError) as exc_info: SwanLabConfig(use_swanlab=True, swanlab_project=None) error_msg = str(exc_info.value) assert "swanlab_project" in error_msg.lower() assert "not set" in error_msg.lower() def test_use_swanlab_true_with_project_valid(self): """Test that use_swanlab=True with project is valid.""" config = SwanLabConfig(use_swanlab=True, swanlab_project="my-project") assert config.use_swanlab is True assert config.swanlab_project == "my-project" def test_use_swanlab_false_no_project_valid(self): """Test that use_swanlab=False without project is valid.""" config = SwanLabConfig(use_swanlab=False, swanlab_project=None) assert config.use_swanlab is False assert config.swanlab_project is None def test_use_swanlab_none_no_project_valid(self): """Test that use_swanlab=None without project is valid.""" config = SwanLabConfig(use_swanlab=None, swanlab_project=None) assert config.use_swanlab is None assert config.swanlab_project is None @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabPluginRegister: """Tests for SwanLabPlugin.register() conflict detection.""" def test_register_without_use_swanlab(self): """Test that register works when SwanLab is not enabled.""" plugin = SwanLabPlugin() cfg = {"use_swanlab": False} # Should not raise plugin.register(cfg) def test_register_use_swanlab_missing_project(self): """Test that use_swanlab=True without project raises ValueError.""" plugin = SwanLabPlugin() cfg = {"use_swanlab": True} with pytest.raises(ValueError) as exc_info: plugin.register(cfg) error_msg = str(exc_info.value) assert "swanlab_project" in error_msg assert "not set" in error_msg assert "Solutions" in error_msg def test_register_use_swanlab_with_project_valid(self): """Test that use_swanlab=True with project is valid.""" plugin = SwanLabPlugin() cfg = {"use_swanlab": True, "swanlab_project": "my-project"} # Should not raise plugin.register(cfg) def test_register_invalid_mode(self): """Test that invalid swanlab_mode raises ValueError.""" plugin = SwanLabPlugin() cfg = { "use_swanlab": True, "swanlab_project": "my-project", "swanlab_mode": "invalid-mode", } with pytest.raises(ValueError) as exc_info: plugin.register(cfg) error_msg = str(exc_info.value) assert "Invalid swanlab_mode" in error_msg assert "cloud" in error_msg assert "local" in error_msg def test_register_valid_modes(self): """Test that all valid modes are accepted.""" plugin = SwanLabPlugin() valid_modes = ["cloud", "local", "offline", "disabled"] for mode in valid_modes: cfg = { "use_swanlab": True, "swanlab_project": "my-project", "swanlab_mode": mode, } # Should not raise plugin.register(cfg) def test_register_auto_enable_swanlab(self): """Test that providing swanlab_project auto-enables use_swanlab.""" plugin = SwanLabPlugin() cfg = {"swanlab_project": "my-project"} plugin.register(cfg) assert cfg["use_swanlab"] is True def test_register_cloud_mode_without_api_key_warns(self, caplog): """Test that cloud mode without API key logs warning.""" plugin = SwanLabPlugin() cfg = { "use_swanlab": True, "swanlab_project": "my-project", "swanlab_mode": "cloud", } # Clear environment variable to ensure it's not set with patch.dict(os.environ, {}, clear=True): with caplog.at_level(logging.WARNING): plugin.register(cfg) # Should log warning about missing API key warning_messages = [record.message for record in caplog.records] assert any("API key" in msg for msg in warning_messages) @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestMultiLoggerDetection: """Tests for multi-logger conflict detection.""" def test_single_logger_no_warning(self, caplog): """Test that single logger doesn't trigger warning.""" plugin = SwanLabPlugin() cfg = {"use_swanlab": True, "swanlab_project": "my-project"} with caplog.at_level(logging.WARNING): plugin.register(cfg) # Should not log multi-logger warning warning_messages = [record.message for record in caplog.records] assert not any("Multiple logging tools" in msg for msg in warning_messages) def test_two_loggers_warning(self, caplog): """Test that two loggers trigger warning.""" plugin = SwanLabPlugin() cfg = { "use_swanlab": True, "swanlab_project": "my-project", "use_wandb": True, } with caplog.at_level(logging.WARNING): plugin.register(cfg) # Should log multi-logger warning warning_messages = [record.message for record in caplog.records] assert any("Multiple logging tools" in msg for msg in warning_messages) assert any("SwanLab" in msg and "WandB" in msg for msg in warning_messages) def test_three_loggers_error(self, caplog): """Test that three loggers trigger error-level warning.""" plugin = SwanLabPlugin() cfg = { "use_swanlab": True, "swanlab_project": "my-project", "use_wandb": True, "use_mlflow": True, } with caplog.at_level(logging.ERROR): plugin.register(cfg) # Should log error-level warning error_messages = [ record.message for record in caplog.records if record.levelno >= logging.ERROR ] assert any("logging tools enabled" in msg for msg in error_messages) def test_multi_logger_with_comet(self, caplog): """Test that Comet is detected in multi-logger scenario.""" plugin = SwanLabPlugin() cfg = { "use_swanlab": True, "swanlab_project": "my-project", "comet_api_key": "test-key", } with caplog.at_level(logging.WARNING): plugin.register(cfg) # Should detect Comet warning_messages = [record.message for record in caplog.records] assert any("Comet" in msg for msg in warning_messages) def test_multi_logger_with_comet_project(self, caplog): """Test that Comet is detected via comet_project_name.""" plugin = SwanLabPlugin() cfg = { "use_swanlab": True, "swanlab_project": "my-project", "comet_project_name": "test-project", } with caplog.at_level(logging.WARNING): plugin.register(cfg) # Should detect Comet warning_messages = [record.message for record in caplog.records] assert any("Comet" in msg for msg in warning_messages) @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabPluginPreModelLoad: """Tests for SwanLabPlugin.pre_model_load() runtime checks.""" def test_pre_model_load_disabled(self): """Test that pre_model_load does nothing when SwanLab is disabled.""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = False # Should not raise plugin.pre_model_load(cfg) def test_pre_model_load_import_error(self): """Test that missing swanlab package raises clear ImportError.""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True with patch( "builtins.__import__", side_effect=ImportError("No module named 'swanlab'") ): with pytest.raises(ImportError) as exc_info: plugin.pre_model_load(cfg) error_msg = str(exc_info.value) assert "SwanLab is not installed" in error_msg assert "pip install swanlab" in error_msg @patch("axolotl.utils.distributed.is_main_process") @patch("axolotl.utils.distributed.get_world_size") def test_pre_model_load_non_main_process_skips( self, mock_get_world_size, mock_is_main_process ): """Test that non-main process skips SwanLab initialization.""" mock_get_world_size.return_value = 2 mock_is_main_process.return_value = False plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True with patch("swanlab.init") as mock_init: plugin.pre_model_load(cfg) # Should NOT call swanlab.init mock_init.assert_not_called() @patch("axolotl.utils.distributed.is_main_process") @patch("axolotl.utils.distributed.get_world_size") def test_pre_model_load_distributed_logging( self, mock_get_world_size, mock_is_main_process, caplog ): """Test that distributed training logs world size info.""" mock_get_world_size.return_value = 4 mock_is_main_process.return_value = True plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True cfg.swanlab_project = "test-project" cfg.swanlab_mode = "cloud" with patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"): with caplog.at_level(logging.INFO): plugin.pre_model_load(cfg) # Should log distributed training info info_messages = [record.message for record in caplog.records] assert any("world_size=4" in msg for msg in info_messages) assert any("Only rank 0" in msg for msg in info_messages) @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabInitKwargs: """Tests for SwanLab initialization with direct parameter passing.""" def test_custom_branding_added_to_config(self): """Test that Axolotl custom branding is added to SwanLab config.""" from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() cfg = DictDefault( { "use_swanlab": True, "swanlab_project": "test-project", } ) init_kwargs = plugin._get_swanlab_init_kwargs(cfg) # Verify custom branding is present assert "config" in init_kwargs assert init_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl" def test_api_key_passed_directly(self): """Test that API key is passed directly to swanlab.init() instead of via env var.""" from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() cfg = DictDefault( { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_api_key": "test-api-key-12345", } ) init_kwargs = plugin._get_swanlab_init_kwargs(cfg) # Verify API key is in init_kwargs (not set as env var) assert "api_key" in init_kwargs assert init_kwargs["api_key"] == "test-api-key-12345" def test_private_deployment_hosts_passed_directly(self): """Test that private deployment hosts are passed directly to swanlab.init().""" from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() cfg = DictDefault( { "use_swanlab": True, "swanlab_project": "internal-project", "swanlab_web_host": "https://swanlab.company.com", "swanlab_api_host": "https://api-swanlab.company.com", } ) init_kwargs = plugin._get_swanlab_init_kwargs(cfg) # Verify private deployment hosts are in init_kwargs assert "web_host" in init_kwargs assert init_kwargs["web_host"] == "https://swanlab.company.com" assert "api_host" in init_kwargs assert init_kwargs["api_host"] == "https://api-swanlab.company.com" @patch("axolotl.utils.distributed.is_main_process") def test_full_private_deployment_init(self, mock_is_main_process): """Test complete initialization with private deployment configuration.""" mock_is_main_process.return_value = True from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() cfg = DictDefault( { "use_swanlab": True, "swanlab_project": "secure-project", "swanlab_experiment_name": "experiment-001", "swanlab_mode": "cloud", "swanlab_api_key": "private-key-xyz", "swanlab_web_host": "https://swanlab.internal.net", "swanlab_api_host": "https://api.swanlab.internal.net", "swanlab_workspace": "research-team", } ) with patch("swanlab.init") as mock_init: plugin.pre_model_load(cfg) # Verify swanlab.init was called with all parameters mock_init.assert_called_once() call_kwargs = mock_init.call_args[1] assert call_kwargs["project"] == "secure-project" assert call_kwargs["experiment_name"] == "experiment-001" assert call_kwargs["mode"] == "cloud" assert call_kwargs["api_key"] == "private-key-xyz" assert call_kwargs["web_host"] == "https://swanlab.internal.net" assert call_kwargs["api_host"] == "https://api.swanlab.internal.net" assert call_kwargs["workspace"] == "research-team" assert call_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl" def test_env_vars_not_set_for_api_params(self): """Test that environment variables are NOT set for API parameters.""" import os from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault # Clear any existing env vars for key in [ "SWANLAB_API_KEY", "SWANLAB_WEB_HOST", "SWANLAB_API_HOST", "SWANLAB_MODE", ]: os.environ.pop(key, None) plugin = SwanLabPlugin() cfg = DictDefault( { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_api_key": "test-key", "swanlab_web_host": "https://test.com", "swanlab_api_host": "https://api-test.com", "swanlab_mode": "cloud", } ) with ( patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("swanlab.init"), ): plugin.pre_model_load(cfg) # Verify env vars were NOT set (simplified approach) # The old _setup_swanlab_env() method is removed, so these shouldn't be set # Note: SwanLab itself might set these, but our plugin shouldn't # We're just testing that our plugin doesn't call _setup_swanlab_env() @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestLarkNotificationIntegration: """Tests for Lark (Feishu) notification integration.""" def test_lark_callback_registration_with_webhook_only(self): """Test Lark callback registration with webhook URL only (no secret).""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True cfg.swanlab_project = "test-project" cfg.swanlab_mode = "local" cfg.swanlab_lark_webhook_url = ( "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" ) cfg.swanlab_lark_secret = None with ( patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"), patch("swanlab.register_callbacks") as mock_register, patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): # Mock LarkCallback import with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback: mock_lark_instance = MagicMock() MockLarkCallback.return_value = mock_lark_instance plugin.pre_model_load(cfg) # Verify LarkCallback was instantiated with correct params MockLarkCallback.assert_called_once_with( webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook", secret=None, ) # Verify callback was registered mock_register.assert_called_once_with([mock_lark_instance]) def test_lark_callback_registration_with_secret(self): """Test Lark callback registration with webhook URL and HMAC secret.""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True cfg.swanlab_project = "test-project" cfg.swanlab_mode = "local" cfg.swanlab_lark_webhook_url = ( "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" ) cfg.swanlab_lark_secret = "test-hmac-secret" with ( patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"), patch("swanlab.register_callbacks") as mock_register, patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback: mock_lark_instance = MagicMock() MockLarkCallback.return_value = mock_lark_instance plugin.pre_model_load(cfg) # Verify LarkCallback was instantiated with secret MockLarkCallback.assert_called_once_with( webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook", secret="test-hmac-secret", ) mock_register.assert_called_once_with([mock_lark_instance]) def test_lark_callback_not_registered_without_webhook(self): """Test that Lark callback is NOT registered when webhook URL not provided.""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True cfg.swanlab_project = "test-project" cfg.swanlab_mode = "local" cfg.swanlab_lark_webhook_url = None # No webhook cfg.swanlab_lark_secret = None with ( patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"), patch("swanlab.register_callbacks") as mock_register, patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): plugin.pre_model_load(cfg) # Verify register_callbacks was NOT called mock_register.assert_not_called() def test_lark_import_error_handled_gracefully(self, caplog): """Test that ImportError for Lark plugin is handled gracefully.""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True cfg.swanlab_project = "test-project" cfg.swanlab_mode = "local" cfg.swanlab_lark_webhook_url = ( "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" ) cfg.swanlab_lark_secret = None with ( patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"), patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): # Mock ImportError for LarkCallback with patch( "swanlab.plugin.notification.LarkCallback", side_effect=ImportError( "No module named 'swanlab.plugin.notification'" ), ): with caplog.at_level(logging.WARNING): plugin.pre_model_load(cfg) # Should log warning about missing Lark plugin warning_messages = [record.message for record in caplog.records] assert any( "Failed to import SwanLab Lark plugin" in msg for msg in warning_messages ) assert any("SwanLab >= 0.3.0" in msg for msg in warning_messages) def test_lark_warning_for_missing_secret(self, caplog): """Test that warning is logged when Lark webhook has no HMAC secret.""" plugin = SwanLabPlugin() cfg = MagicMock() cfg.use_swanlab = True cfg.swanlab_project = "test-project" cfg.swanlab_mode = "local" cfg.swanlab_lark_webhook_url = ( "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" ) cfg.swanlab_lark_secret = None # No secret with ( patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"), patch("swanlab.register_callbacks"), patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): with patch("swanlab.plugin.notification.LarkCallback"): with caplog.at_level(logging.WARNING): plugin.pre_model_load(cfg) # Should log warning about missing secret warning_messages = [record.message for record in caplog.records] assert any( "no secret configured" in msg.lower() for msg in warning_messages ) assert any("swanlab_lark_secret" in msg for msg in warning_messages) @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabPluginIntegration: """Integration tests for SwanLab plugin lifecycle.""" def test_full_lifecycle_valid_config(self): """Test full plugin lifecycle with valid configuration.""" plugin = SwanLabPlugin() # Register cfg_dict = { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_mode": "local", } plugin.register(cfg_dict) # Pre-model load (mock SwanLab) cfg_obj = MagicMock() cfg_obj.use_swanlab = True cfg_obj.swanlab_project = "test-project" cfg_obj.swanlab_mode = "local" cfg_obj.swanlab_lark_webhook_url = None # No Lark with ( patch("swanlab.init") as mock_init, patch("swanlab.__version__", "0.3.0"), patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): plugin.pre_model_load(cfg_obj) # Should call swanlab.init mock_init.assert_called_once() def test_lifecycle_with_multi_logger_warning(self, caplog): """Test lifecycle with multi-logger warning.""" plugin = SwanLabPlugin() cfg_dict = { "use_swanlab": True, "swanlab_project": "test-project", "use_wandb": True, } with caplog.at_level(logging.WARNING): plugin.register(cfg_dict) # Should have multi-logger warning warning_messages = [record.message for record in caplog.records] assert any("Multiple logging tools" in msg for msg in warning_messages) def test_lifecycle_invalid_config_fails_early(self): """Test that invalid config fails at register stage.""" plugin = SwanLabPlugin() cfg_dict = { "use_swanlab": True, # Missing swanlab_project } # Should fail at register, not pre_model_load with pytest.raises(ValueError): plugin.register(cfg_dict) def test_full_lifecycle_with_lark_notifications(self): """Test full lifecycle including Lark notification registration.""" plugin = SwanLabPlugin() # Register cfg_dict = { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_mode": "cloud", } plugin.register(cfg_dict) # Pre-model load with Lark config cfg_obj = MagicMock() cfg_obj.use_swanlab = True cfg_obj.swanlab_project = "test-project" cfg_obj.swanlab_mode = "cloud" cfg_obj.swanlab_lark_webhook_url = ( "https://open.feishu.cn/open-apis/bot/v2/hook/test" ) cfg_obj.swanlab_lark_secret = "secret123" with ( patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"), patch("swanlab.register_callbacks") as mock_register, patch("axolotl.utils.distributed.is_main_process", return_value=True), patch("axolotl.utils.distributed.get_world_size", return_value=1), ): with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback: mock_lark_instance = MagicMock() MockLarkCallback.return_value = mock_lark_instance plugin.pre_model_load(cfg_obj) # Verify both SwanLab init AND Lark callback registration MockLarkCallback.assert_called_once() mock_register.assert_called_once_with([mock_lark_instance]) @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestCompletionLogger: """Tests for CompletionLogger utility class.""" def test_completion_logger_initialization(self): """Test CompletionLogger initializes with correct maxlen.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=64) assert logger.maxlen == 64 assert len(logger) == 0 def test_add_dpo_completion(self): """Test adding DPO completions to buffer.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_dpo_completion( step=0, prompt="What is AI?", chosen="Artificial Intelligence is...", rejected="AI means...", reward_diff=0.5, ) assert len(logger) == 1 entry = logger.data[0] assert entry["step"] == 0 assert entry["prompt"] == "What is AI?" assert entry["chosen"] == "Artificial Intelligence is..." assert entry["rejected"] == "AI means..." assert entry["reward_diff"] == 0.5 def test_add_kto_completion(self): """Test adding KTO completions to buffer.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_kto_completion( step=1, prompt="Explain quantum physics", completion="Quantum physics is...", label=True, reward=0.8, ) assert len(logger) == 1 entry = logger.data[0] assert entry["step"] == 1 assert entry["prompt"] == "Explain quantum physics" assert entry["completion"] == "Quantum physics is..." assert entry["label"] == "desirable" assert entry["reward"] == 0.8 def test_add_orpo_completion(self): """Test adding ORPO completions to buffer.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_orpo_completion( step=2, prompt="Write a poem", chosen="Roses are red...", rejected="Violets are blue...", log_odds_ratio=1.2, ) assert len(logger) == 1 entry = logger.data[0] assert entry["step"] == 2 assert entry["chosen"] == "Roses are red..." assert entry["rejected"] == "Violets are blue..." assert entry["log_odds_ratio"] == 1.2 def test_add_grpo_completion(self): """Test adding GRPO completions to buffer.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_grpo_completion( step=3, prompt="Solve this problem", completion="The answer is 42", reward=0.9, advantage=0.3, ) assert len(logger) == 1 entry = logger.data[0] assert entry["step"] == 3 assert entry["completion"] == "The answer is 42" assert entry["reward"] == 0.9 assert entry["advantage"] == 0.3 def test_memory_bounded_buffer(self): """Test that buffer respects maxlen and drops oldest entries.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=3) # Add 5 completions for i in range(5): logger.add_dpo_completion( step=i, prompt=f"Prompt {i}", chosen=f"Chosen {i}", rejected=f"Rejected {i}", ) # Should only keep last 3 assert len(logger) == 3 assert logger.data[0]["step"] == 2 # Oldest kept assert logger.data[1]["step"] == 3 assert logger.data[2]["step"] == 4 # Newest def test_log_to_swanlab_when_not_initialized(self): """Test logging gracefully fails when SwanLab not initialized.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_dpo_completion( step=0, prompt="Test", chosen="A", rejected="B", ) with patch("swanlab.get_run", return_value=None): result = logger.log_to_swanlab() assert result is False # Should fail gracefully def test_log_to_swanlab_success(self): """Test successful logging to SwanLab.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_dpo_completion( step=0, prompt="Test prompt", chosen="Chosen response", rejected="Rejected response", reward_diff=0.5, ) with ( patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log, patch("swanlab.echarts.Table") as MockTable, ): mock_get_run.return_value = MagicMock() # SwanLab initialized mock_table_instance = MagicMock() MockTable.return_value = mock_table_instance result = logger.log_to_swanlab(table_name="test_table") assert result is True mock_log.assert_called_once() mock_table_instance.add.assert_called_once() def test_clear_buffer(self): """Test clearing the completion buffer.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=10) logger.add_dpo_completion( step=0, prompt="Test", chosen="A", rejected="B", ) assert len(logger) == 1 logger.clear() assert len(logger) == 0 def test_repr(self): """Test string representation.""" from axolotl.integrations.swanlab.completion_logger import CompletionLogger logger = CompletionLogger(maxlen=128) logger.add_dpo_completion( step=0, prompt="Test", chosen="A", rejected="B", ) repr_str = repr(logger) assert "CompletionLogger" in repr_str assert "maxlen=128" in repr_str assert "buffered=1/128" in repr_str @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabRLHFCompletionCallback: """Tests for SwanLabRLHFCompletionCallback.""" def test_callback_initialization(self): """Test callback initializes with correct parameters.""" from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback callback = SwanLabRLHFCompletionCallback( log_interval=50, max_completions=64, table_name="custom_table", ) assert callback.log_interval == 50 assert callback.logger.maxlen == 64 assert callback.table_name == "custom_table" assert callback.trainer_type is None def test_trainer_type_detection_dpo(self): """Test DPO trainer type is detected correctly.""" from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback callback = SwanLabRLHFCompletionCallback() # Mock trainer with DPO in name mock_trainer = MagicMock() mock_trainer.__class__.__name__ = "AxolotlDPOTrainer" callback.on_init_end( args=MagicMock(), state=MagicMock(), control=MagicMock(), trainer=mock_trainer, ) assert callback.trainer_type == "dpo" def test_trainer_type_detection_kto(self): """Test KTO trainer type is detected correctly.""" from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback callback = SwanLabRLHFCompletionCallback() mock_trainer = MagicMock() mock_trainer.__class__.__name__ = "AxolotlKTOTrainer" callback.on_init_end( args=MagicMock(), state=MagicMock(), control=MagicMock(), trainer=mock_trainer, ) assert callback.trainer_type == "kto" def test_on_train_end_logs_completions(self): """Test that completions are logged at end of training.""" from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback callback = SwanLabRLHFCompletionCallback() callback.trainer_type = "dpo" # Add some completions to buffer callback.logger.add_dpo_completion( step=0, prompt="Test", chosen="A", rejected="B", ) with patch.object(callback.logger, "log_to_swanlab") as mock_log: callback.on_train_end( args=MagicMock(), state=MagicMock(global_step=100), control=MagicMock(), ) # Should log remaining completions mock_log.assert_called_once() @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabPluginCompletionIntegration: """Integration tests for completion logging in SwanLabPlugin.""" def test_completion_callback_registered_for_dpo_trainer(self): """Test that completion callback is registered for DPO trainer.""" from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() plugin.swanlab_initialized = True # Simulate SwanLab initialized cfg = { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_log_completions": True, "swanlab_completion_log_interval": 50, "swanlab_completion_max_buffer": 64, } cfg_obj = DictDefault(cfg) # Mock DPO trainer mock_trainer = MagicMock() mock_trainer.__class__.__name__ = "AxolotlDPOTrainer" mock_trainer.state = MagicMock(max_steps=1000) mock_trainer.args = MagicMock( num_train_epochs=3, train_batch_size=4, gradient_accumulation_steps=2, ) with patch("swanlab.config.update"): plugin.post_trainer_create(cfg_obj, mock_trainer) # Verify callback was added mock_trainer.add_callback.assert_called_once() callback = mock_trainer.add_callback.call_args[0][0] assert callback.__class__.__name__ == "SwanLabRLHFCompletionCallback" assert callback.log_interval == 50 assert callback.logger.maxlen == 64 def test_completion_callback_not_registered_for_non_rlhf_trainer(self): """Test that completion callback is NOT registered for non-RLHF trainers.""" from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() plugin.swanlab_initialized = True cfg = { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_log_completions": True, } cfg_obj = DictDefault(cfg) # Mock regular SFT trainer (not RLHF) mock_trainer = MagicMock() mock_trainer.__class__.__name__ = "AxolotlTrainer" # Not RLHF mock_trainer.state = MagicMock(max_steps=1000) mock_trainer.args = MagicMock() with patch("swanlab.config.update"): plugin.post_trainer_create(cfg_obj, mock_trainer) # Callback should NOT be added for non-RLHF trainer mock_trainer.add_callback.assert_not_called() def test_completion_callback_not_registered_when_disabled(self): """Test that completion callback is not registered when disabled in config.""" from axolotl.integrations.swanlab.plugins import SwanLabPlugin from axolotl.utils.dict import DictDefault plugin = SwanLabPlugin() plugin.swanlab_initialized = True cfg = { "use_swanlab": True, "swanlab_project": "test-project", "swanlab_log_completions": False, # Disabled } cfg_obj = DictDefault(cfg) # Mock DPO trainer mock_trainer = MagicMock() mock_trainer.__class__.__name__ = "AxolotlDPOTrainer" mock_trainer.state = MagicMock(max_steps=1000) mock_trainer.args = MagicMock() with patch("swanlab.config.update"): plugin.post_trainer_create(cfg_obj, mock_trainer) # Callback should NOT be added when disabled mock_trainer.add_callback.assert_not_called() @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") class TestSwanLabProfiling: """Tests for SwanLab profiling utilities.""" def test_profiling_context_logs_duration(self): """Test that profiling context logs execution duration.""" from axolotl.integrations.swanlab.profiling import swanlab_profiling_context # Mock trainer with SwanLab enabled mock_trainer = MagicMock() mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True) mock_trainer.__class__.__name__ = "TestTrainer" with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() # SwanLab initialized with swanlab_profiling_context(mock_trainer, "test_function"): time.sleep(0.01) # Simulate work # Verify log was called with correct metric name mock_log.assert_called_once() logged_data = mock_log.call_args[0][0] assert "profiling/Time taken: TestTrainer.test_function" in logged_data # Duration should be > 0.01 seconds assert ( logged_data["profiling/Time taken: TestTrainer.test_function"] >= 0.01 ) def test_profiling_context_skips_when_swanlab_disabled(self): """Test that profiling is skipped when SwanLab is disabled.""" from axolotl.integrations.swanlab.profiling import swanlab_profiling_context mock_trainer = MagicMock() mock_trainer.axolotl_cfg = MagicMock(use_swanlab=False) # Disabled with patch("swanlab.log") as mock_log: with swanlab_profiling_context(mock_trainer, "test_function"): time.sleep(0.01) # Should NOT log when disabled mock_log.assert_not_called() def test_profiling_context_skips_when_swanlab_not_initialized(self): """Test that profiling is skipped when SwanLab not initialized.""" from axolotl.integrations.swanlab.profiling import swanlab_profiling_context mock_trainer = MagicMock() mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True) with ( patch("swanlab.get_run", return_value=None), patch("swanlab.log") as mock_log, ): with swanlab_profiling_context(mock_trainer, "test_function"): time.sleep(0.01) # Should NOT log when not initialized mock_log.assert_not_called() def test_profiling_decorator(self): """Test swanlab_profile decorator.""" from axolotl.integrations.swanlab.profiling import swanlab_profile class MockTrainer: def __init__(self): self.cfg = MagicMock(use_swanlab=True) @swanlab_profile def expensive_method(self, x): time.sleep(0.01) return x * 2 trainer = MockTrainer() with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() result = trainer.expensive_method(5) # Verify method still works correctly assert result == 10 # Verify profiling was logged mock_log.assert_called_once() logged_data = mock_log.call_args[0][0] assert "profiling/Time taken: MockTrainer.expensive_method" in logged_data def test_profiling_config(self): """Test ProfilingConfig class.""" from axolotl.integrations.swanlab.profiling import ProfilingConfig config = ProfilingConfig( enabled=True, min_duration_ms=1.0, log_interval=5, ) # Test enabled check assert config.enabled is True # Test minimum duration filtering assert config.should_log("func1", 0.0001) is False # 0.1ms < 1.0ms threshold assert config.should_log("func2", 0.002) is True # 2.0ms > 1.0ms threshold # Test log interval assert config.should_log("func3", 0.002) is True # 1st call assert config.should_log("func3", 0.002) is False # 2nd call assert config.should_log("func3", 0.002) is False # 3rd call assert config.should_log("func3", 0.002) is False # 4th call assert config.should_log("func3", 0.002) is True # 5th call (interval=5) def test_profiling_config_when_disabled(self): """Test ProfilingConfig when disabled.""" from axolotl.integrations.swanlab.profiling import ProfilingConfig config = ProfilingConfig(enabled=False) # Should never log when disabled assert config.should_log("func1", 100.0) is False def test_profiling_context_advanced(self): """Test advanced profiling context with custom config.""" from axolotl.integrations.swanlab.profiling import ( ProfilingConfig, swanlab_profiling_context_advanced, ) mock_trainer = MagicMock() mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True) mock_trainer.__class__.__name__ = "TestTrainer" # Config that filters out very fast operations config = ProfilingConfig(min_duration_ms=10.0) # 10ms minimum with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() # Fast operation (< 10ms) - should NOT log with swanlab_profiling_context_advanced(mock_trainer, "fast_op", config): time.sleep(0.001) # 1ms mock_log.assert_not_called() # Slow operation (> 10ms) - should log with swanlab_profiling_context_advanced(mock_trainer, "slow_op", config): time.sleep(0.015) # 15ms mock_log.assert_called_once() def test_profiling_with_exception(self): """Test that profiling still logs even when exception occurs.""" from axolotl.integrations.swanlab.profiling import swanlab_profiling_context mock_trainer = MagicMock() mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True) mock_trainer.__class__.__name__ = "TestTrainer" with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: mock_get_run.return_value = MagicMock() try: with swanlab_profiling_context(mock_trainer, "error_function"): time.sleep(0.01) raise ValueError("Test error") except ValueError: pass # Expected # Should still log duration even with exception mock_log.assert_called_once()