distributed fix
@@ -15,8 +15,6 @@ import psutil
import torch
import yaml

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger(__name__)

POSTHOG_HOST = "https://app.posthog.com"
@@ -77,6 +75,43 @@ RELEVANT_PACKAGES = {
}


def is_main_process() -> bool:
    """
    Check whether we're running in the main process.

    Note:
        We're using this function instead of `torch.utils.distributed.is_main_process`,
        since the latter causes issues with DeepSpeed world_size. This function avoids
        that issue by checking env vars that are set by various launchers.

    Returns:
        Whether we're running in the main process.
    """
    # If PyTorch distributed is already initialized, use it
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0

    # Otherwise check environment variables for global rank
    # NOTE: need to verify this in SLURM / OpenMPI environments
    global_rank = int(
        os.environ.get(
            "RANK",
            os.environ.get(
                "GLOBAL_RANK",
                os.environ.get(
                    "SLURM_PROCID",
                    os.environ.get(
                        "OMPI_COMM_WORLD_RANK",
                        "0",
                    ),
                ),
            ),
        )
    )

    return global_rank == 0


class TelemetryManager:
    """Manages telemetry collection and transmission"""
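
For context, a minimal sketch (not part of the commit) of how the env-var fallback chain resolves when `torch.distributed` has not been initialized yet; the launcher and rank value here are illustrative, assuming a torchrun-style launch that exports `RANK` for each worker:

```python
import os

# torchrun-style launchers export RANK per worker; rank 0 is the main process.
os.environ["RANK"] = "3"  # illustrative value, not from the commit

# Same fallback order as the function: RANK -> GLOBAL_RANK -> SLURM_PROCID
# -> OMPI_COMM_WORLD_RANK -> default "0" (single-process runs).
global_rank = int(
    os.environ.get(
        "RANK",
        os.environ.get(
            "GLOBAL_RANK",
            os.environ.get("SLURM_PROCID", os.environ.get("OMPI_COMM_WORLD_RANK", "0")),
        ),
    )
)

print(global_rank == 0)  # False: this worker is not the main process
```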