use can_device_access_peer for P2P checks (#3209) [skip ci]

* use can_device_access_peer for P2P checks

* also log a warning when automatically setting NCCL_P2P_DISABLE=1
This commit is contained in:
Wing Lian
2025-10-09 14:17:31 -04:00
committed by GitHub
parent 37f78c8592
commit 3a5c97e6e5
2 changed files with 21 additions and 40 deletions

View File

@@ -3,66 +3,46 @@ utils to get GPU info for the current environment
"""
import os
import subprocess # nosec B404
from importlib.metadata import version
import torch
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
get_gpu_info,
)
from packaging.version import Version, parse
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
if not check_runpod_p2p_support():
if not check_cuda_p2p_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
if any(
unsupported_device in device_name
for device_name in device_names
for unsupported_device in unsupported_devices
):
return False
except Exception: # nosec B110
pass
return True
def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
def check_cuda_p2p_support() -> bool:
try:
gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
except ValueError:
return True
if gpu_count >= 2:
# run `nvidia-smi topo -p2p n` and inspect the GPU0 row
if world_size > 1:
node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8"))
local_other_rank = (local_rank // node_world_size) * node_world_size
local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0
try:
result = subprocess.run( # nosec B603 B607
["nvidia-smi", "topo", "-p2p", "n"],
check=True,
capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank)
except AssertionError as exc:
# some sort of logic error in indexing processes, assume p2p is fine for now
LOG.warning(exc)
return True
# consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return can_p2p
return True

View File

@@ -641,6 +641,7 @@ def setup_parallelism_envs(cfg):
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`")
os.environ["NCCL_P2P_DISABLE"] = "1"
# TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
if cfg.fsdp or cfg.fsdp_config: