diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py index 64da12ce1..71bb5ca2a 100644 --- a/src/axolotl/telemetry/manager.py +++ b/src/axolotl/telemetry/manager.py @@ -15,8 +15,6 @@ import psutil import torch import yaml -from axolotl.utils.distributed import is_main_process - LOG = logging.getLogger(__name__) POSTHOG_HOST = "https://app.posthog.com" @@ -77,6 +75,43 @@ RELEVANT_PACKAGES = { } +def is_main_process() -> bool: + """ + Check whether we're running in the main process. + + Note: + We're using this function instead of `torch.utils.distributed.is_main_process` + causes issues with DeepSpeed world_size since. This function avoids that issue + by checking env vars that are set by various launchers. + + Returns: + Whether we're running in the main process. + """ + # If PyTorch distributed is already initialized, use it + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() == 0 + + # Otherwise check environment variables for global rank + # NOTE: need to verify this in SLURM / OpenMPI environments + global_rank = int( + os.environ.get( + "RANK", + os.environ.get( + "GLOBAL_RANK", + os.environ.get( + "SLURM_PROCID", + os.environ.get( + "OMPI_COMM_WORLD_RANK", + "0", + ), + ), + ), + ) + ) + + return global_rank == 0 + + class TelemetryManager: """Manages telemetry collection and transmission""" diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py index 441d94a10..e01ab9339 100644 --- a/tests/telemetry/test_manager.py +++ b/tests/telemetry/test_manager.py @@ -40,7 +40,7 @@ def manager(telemetry_manager_class, mock_whitelist): """Create a TelemetryManager instance with mocked dependencies""" with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch( "axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist - ), patch("axolotl.telemetry.manager.is_main_process", return_value=True): + ), patch.dict(os.environ, {"RANK": "0"}): manager = telemetry_manager_class() # Manually enable for most tests manager.enabled = True @@ -49,7 +49,9 @@ def manager(telemetry_manager_class, mock_whitelist): def test_singleton_instance(telemetry_manager_class): """Test that TelemetryManager is a singleton""" - with patch("posthog.capture"), patch("time.sleep"): + with patch("posthog.capture"), patch("time.sleep"), patch.dict( + os.environ, {"RANK": "0"} + ): first = telemetry_manager_class() second = telemetry_manager_class() assert first is second @@ -58,36 +60,30 @@ def test_singleton_instance(telemetry_manager_class): def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class): """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1""" - with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1"}), patch( - "axolotl.telemetry.manager.is_main_process", return_value=True - ): + with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}): manager = telemetry_manager_class() assert not manager.enabled def test_telemetry_disabled_with_do_not_track(telemetry_manager_class): """Test that telemetry is disabled when DO_NOT_TRACK=1""" - with patch.dict(os.environ, {"DO_NOT_TRACK": "1"}), patch( - "axolotl.telemetry.manager.is_main_process", return_value=True - ): + with patch.dict(os.environ, {"DO_NOT_TRACK": "1", "RANK": "0"}): manager = telemetry_manager_class() assert not manager.enabled def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): """Test that telemetry is disabled for non-main processes""" - with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), patch( - "axolotl.telemetry.manager.is_main_process", return_value=False - ): + with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}): manager = telemetry_manager_class() assert not manager.enabled def test_telemetry_enabled_by_default(telemetry_manager_class): """Test that telemetry is enabled by default""" - with patch.dict(os.environ, {}, clear=True), patch( - "axolotl.telemetry.manager.is_main_process", return_value=True - ), patch("time.sleep"), patch("logging.Logger.warning"): + with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch( + "logging.Logger.warning" + ): manager = telemetry_manager_class() assert manager.enabled assert not manager.explicit_enable @@ -95,13 +91,9 @@ def test_telemetry_enabled_by_default(telemetry_manager_class): def test_explicit_enable_disables_warning(telemetry_manager_class): """Test that explicit enabling prevents warning""" - with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), patch( + with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch( "logging.Logger.warning" - ) as mock_warning, patch( - "axolotl.telemetry.manager.is_main_process", return_value=True - ), patch( - "time.sleep" - ): + ) as mock_warning, patch("time.sleep"): manager = telemetry_manager_class() assert manager.enabled assert manager.explicit_enable @@ -111,13 +103,9 @@ def test_explicit_enable_disables_warning(telemetry_manager_class): def test_warning_displayed_for_implicit_enable(telemetry_manager_class): """Test that warning is displayed when telemetry is implicitly enabled""" - with patch.dict(os.environ, {}, clear=True), patch( + with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch( "logging.Logger.warning" - ) as mock_warning, patch( - "axolotl.telemetry.manager.is_main_process", return_value=True - ), patch( - "time.sleep" - ): + ) as mock_warning, patch("time.sleep"): manager = telemetry_manager_class() assert manager.enabled assert not manager.explicit_enable