distributed fix

Dan Saunders
2025-02-26 02:55:44 +00:00
parent 5a2a80cc48
commit b2f1fc109a
2 changed files with 51 additions and 28 deletions


@@ -15,8 +15,6 @@ import psutil
import torch
import yaml
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
@@ -77,6 +75,43 @@ RELEVANT_PACKAGES = {
}
def is_main_process() -> bool:
    """
    Check whether we're running in the main process.

    Note:
        We're using this function instead of `torch.utils.distributed.is_main_process`
        since the latter causes issues with DeepSpeed world_size. This function avoids
        that issue by checking env vars that are set by various launchers.

    Returns:
        Whether we're running in the main process.
    """
    # If PyTorch distributed is already initialized, use it
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0

    # Otherwise check environment variables for global rank
    # NOTE: need to verify this in SLURM / OpenMPI environments
    global_rank = int(
        os.environ.get(
            "RANK",
            os.environ.get(
                "GLOBAL_RANK",
                os.environ.get(
                    "SLURM_PROCID",
                    os.environ.get(
                        "OMPI_COMM_WORLD_RANK",
                        "0",
                    ),
                ),
            ),
        )
    )

    return global_rank == 0
class TelemetryManager:
"""Manages telemetry collection and transmission"""