distributed fix
This commit is contained in:
@@ -15,8 +15,6 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
POSTHOG_HOST = "https://app.posthog.com"
|
POSTHOG_HOST = "https://app.posthog.com"
|
||||||
@@ -77,6 +75,43 @@ RELEVANT_PACKAGES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process() -> bool:
|
||||||
|
"""
|
||||||
|
Check whether we're running in the main process.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
We're using this function instead of `torch.utils.distributed.is_main_process`
|
||||||
|
causes issues with DeepSpeed world_size since. This function avoids that issue
|
||||||
|
by checking env vars that are set by various launchers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether we're running in the main process.
|
||||||
|
"""
|
||||||
|
# If PyTorch distributed is already initialized, use it
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
return torch.distributed.get_rank() == 0
|
||||||
|
|
||||||
|
# Otherwise check environment variables for global rank
|
||||||
|
# NOTE: need to verify this in SLURM / OpenMPI environments
|
||||||
|
global_rank = int(
|
||||||
|
os.environ.get(
|
||||||
|
"RANK",
|
||||||
|
os.environ.get(
|
||||||
|
"GLOBAL_RANK",
|
||||||
|
os.environ.get(
|
||||||
|
"SLURM_PROCID",
|
||||||
|
os.environ.get(
|
||||||
|
"OMPI_COMM_WORLD_RANK",
|
||||||
|
"0",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return global_rank == 0
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
class TelemetryManager:
|
||||||
"""Manages telemetry collection and transmission"""
|
"""Manages telemetry collection and transmission"""
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def manager(telemetry_manager_class, mock_whitelist):
|
|||||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
"""Create a TelemetryManager instance with mocked dependencies"""
|
||||||
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
|
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
|
||||||
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
|
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
|
||||||
), patch("axolotl.telemetry.manager.is_main_process", return_value=True):
|
), patch.dict(os.environ, {"RANK": "0"}):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
# Manually enable for most tests
|
# Manually enable for most tests
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
@@ -49,7 +49,9 @@ def manager(telemetry_manager_class, mock_whitelist):
|
|||||||
|
|
||||||
def test_singleton_instance(telemetry_manager_class):
|
def test_singleton_instance(telemetry_manager_class):
|
||||||
"""Test that TelemetryManager is a singleton"""
|
"""Test that TelemetryManager is a singleton"""
|
||||||
with patch("posthog.capture"), patch("time.sleep"):
|
with patch("posthog.capture"), patch("time.sleep"), patch.dict(
|
||||||
|
os.environ, {"RANK": "0"}
|
||||||
|
):
|
||||||
first = telemetry_manager_class()
|
first = telemetry_manager_class()
|
||||||
second = telemetry_manager_class()
|
second = telemetry_manager_class()
|
||||||
assert first is second
|
assert first is second
|
||||||
@@ -58,36 +60,30 @@ def test_singleton_instance(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1"}), patch(
|
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}):
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=True
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
||||||
with patch.dict(os.environ, {"DO_NOT_TRACK": "1"}), patch(
|
with patch.dict(os.environ, {"DO_NOT_TRACK": "1", "RANK": "0"}):
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=True
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled for non-main processes"""
|
"""Test that telemetry is disabled for non-main processes"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), patch(
|
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}):
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=False
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_enabled_by_default(telemetry_manager_class):
|
def test_telemetry_enabled_by_default(telemetry_manager_class):
|
||||||
"""Test that telemetry is enabled by default"""
|
"""Test that telemetry is enabled by default"""
|
||||||
with patch.dict(os.environ, {}, clear=True), patch(
|
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch(
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=True
|
"logging.Logger.warning"
|
||||||
), patch("time.sleep"), patch("logging.Logger.warning"):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert manager.enabled
|
assert manager.enabled
|
||||||
assert not manager.explicit_enable
|
assert not manager.explicit_enable
|
||||||
@@ -95,13 +91,9 @@ def test_telemetry_enabled_by_default(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_explicit_enable_disables_warning(telemetry_manager_class):
|
def test_explicit_enable_disables_warning(telemetry_manager_class):
|
||||||
"""Test that explicit enabling prevents warning"""
|
"""Test that explicit enabling prevents warning"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), patch(
|
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch(
|
||||||
"logging.Logger.warning"
|
"logging.Logger.warning"
|
||||||
) as mock_warning, patch(
|
) as mock_warning, patch("time.sleep"):
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=True
|
|
||||||
), patch(
|
|
||||||
"time.sleep"
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert manager.enabled
|
assert manager.enabled
|
||||||
assert manager.explicit_enable
|
assert manager.explicit_enable
|
||||||
@@ -111,13 +103,9 @@ def test_explicit_enable_disables_warning(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
|
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
|
||||||
"""Test that warning is displayed when telemetry is implicitly enabled"""
|
"""Test that warning is displayed when telemetry is implicitly enabled"""
|
||||||
with patch.dict(os.environ, {}, clear=True), patch(
|
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch(
|
||||||
"logging.Logger.warning"
|
"logging.Logger.warning"
|
||||||
) as mock_warning, patch(
|
) as mock_warning, patch("time.sleep"):
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=True
|
|
||||||
), patch(
|
|
||||||
"time.sleep"
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert manager.enabled
|
assert manager.enabled
|
||||||
assert not manager.explicit_enable
|
assert not manager.explicit_enable
|
||||||
|
|||||||
Reference in New Issue
Block a user